diff --git a/.circleci/config.yml b/.circleci/config.yml
index 43f2d58acdf31..86636d6d20314 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -2,7 +2,7 @@ version: 2
defaults: &defaults
docker:
- - image: palantirtechnologies/circle-spark-base:0.1.0
+ - image: palantirtechnologies/circle-spark-base:0.1.3
resource_class: xlarge
environment: &defaults-environment
TERM: dumb
@@ -128,7 +128,7 @@ jobs:
<<: *defaults
# Some part of the maven setup fails if there's no R, so we need to use the R image here
docker:
- - image: palantirtechnologies/circle-spark-r:0.1.0
+ - image: palantirtechnologies/circle-spark-r:0.1.3
steps:
# Saves us from recompiling every time...
- restore_cache:
@@ -147,12 +147,7 @@ jobs:
keys:
- build-binaries-{{ checksum "build/mvn" }}-{{ checksum "build/sbt" }}
- build-binaries-
- - run: |
- ./build/mvn -T1C -DskipTests -Phadoop-cloud -Phadoop-palantir -Pkinesis-asl -Pkubernetes -Pyarn -Psparkr install \
- | tee -a "/tmp/mvn-install.log"
- - store_artifacts:
- path: /tmp/mvn-install.log
- destination: mvn-install.log
+ - run: ./build/mvn -DskipTests -Phadoop-cloud -Phadoop-palantir -Pkinesis-asl -Pkubernetes -Pyarn -Psparkr install
# Get sbt to run trivially, ensures its launcher is downloaded under build/
- run: ./build/sbt -h || true
- save_cache:
@@ -300,7 +295,7 @@ jobs:
# depends on build-sbt, but we only need the assembly jars
<<: *defaults
docker:
- - image: palantirtechnologies/circle-spark-python:0.1.0
+ - image: palantirtechnologies/circle-spark-python:0.1.3
parallelism: 2
steps:
- *checkout-code
@@ -325,7 +320,7 @@ jobs:
# depends on build-sbt, but we only need the assembly jars
<<: *defaults
docker:
- - image: palantirtechnologies/circle-spark-r:0.1.0
+ - image: palantirtechnologies/circle-spark-r:0.1.3
steps:
- *checkout-code
- attach_workspace:
@@ -438,7 +433,7 @@ jobs:
<<: *defaults
# Some part of the maven setup fails if there's no R, so we need to use the R image here
docker:
- - image: palantirtechnologies/circle-spark-r:0.1.0
+ - image: palantirtechnologies/circle-spark-r:0.1.3
steps:
- *checkout-code
- restore_cache:
@@ -458,7 +453,7 @@ jobs:
deploy-gradle:
<<: *defaults
docker:
- - image: palantirtechnologies/circle-spark-r:0.1.0
+ - image: palantirtechnologies/circle-spark-r:0.1.3
steps:
- *checkout-code
- *restore-gradle-wrapper-cache
@@ -470,7 +465,7 @@ jobs:
<<: *defaults
# Some part of the maven setup fails if there's no R, so we need to use the R image here
docker:
- - image: palantirtechnologies/circle-spark-r:0.1.0
+ - image: palantirtechnologies/circle-spark-r:0.1.3
steps:
# This cache contains the whole project after version was set and mvn package was called
# Restoring first (and instead of checkout) as mvn versions:set mutates real source code...
diff --git a/.gitignore b/.gitignore
index 61f9349f42a76..7114705bfccf4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -81,7 +81,6 @@ work/
.credentials
dev/pr-deps
docs/.jekyll-metadata
-*.crc
# For Hive
TempStatsStore/
diff --git a/FORK.md b/FORK.md
index 7efc0e7961237..d1262b890b67d 100644
--- a/FORK.md
+++ b/FORK.md
@@ -29,3 +29,5 @@
# Reverted
 * [SPARK-25908](https://issues.apache.org/jira/browse/SPARK-25908) - Removal of `monotonically_increasing_id`, `toDegrees`, `toRadians`, `approxCountDistinct`, `unionAll`
* [SPARK-25862](https://issues.apache.org/jira/browse/SPARK-25862) - Removal of `unboundedPreceding`, `unboundedFollowing`, `currentRow`
+* [SPARK-26127](https://issues.apache.org/jira/browse/SPARK-26127) - Removal of deprecated setters from tree regression and classification models
+* [SPARK-25867](https://issues.apache.org/jira/browse/SPARK-25867) - Removal of KMeans computeCost
diff --git a/R/WINDOWS.md b/R/WINDOWS.md
index da668a69b8679..33a4c850cfdac 100644
--- a/R/WINDOWS.md
+++ b/R/WINDOWS.md
@@ -3,7 +3,7 @@
To build SparkR on Windows, the following steps are required
1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to
-include Rtools and R in `PATH`.
+include Rtools and R in `PATH`. Note that support for R prior to version 3.4 is deprecated as of Spark 3.0.0.
2. Install
[JDK8](http://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html) and set
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 70f22bf895495..74cdbd185e570 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -15,7 +15,7 @@ URL: http://www.apache.org/ http://spark.apache.org/
BugReports: http://spark.apache.org/contributing.html
SystemRequirements: Java (== 8)
Depends:
- R (>= 3.0),
+ R (>= 3.1),
methods
Suggests:
knitr,
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index fb158e17fc19c..ad9cd845f696c 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -767,6 +767,14 @@ setMethod("repartition",
#' using \code{spark.sql.shuffle.partitions} as number of partitions.}
#'}
#'
+#' At least one partition-by expression must be specified.
+#' When no explicit sort order is specified, "ascending nulls first" is assumed.
+#'
+#' Note that due to performance reasons this method uses sampling to estimate the ranges.
+#' Hence, the output may not be consistent, since sampling can return different values.
+#' The sample size can be controlled by the config
+#' \code{spark.sql.execution.rangeExchange.sampleSizePerPartition}.
+#'
#' @param x a SparkDataFrame.
#' @param numPartitions the number of partitions to use.
#' @param col the column by which the range partitioning will be performed.
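For context, the documentation added above describes range partitioning and its sampling-based range estimation. Below is a minimal Java sketch of the same Dataset API call (the column name, partition count, and data are illustrative only, not part of this change):

```java
import static org.apache.spark.sql.functions.col;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class RepartitionByRangeExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("repartitionByRange-sketch")
        .master("local[*]")
        .getOrCreate();

    Dataset<Row> df = spark.range(0, 1000).toDF("id");

    // Range-partition into 8 partitions by the "id" column. The ranges are
    // estimated by sampling, controlled by
    // spark.sql.execution.rangeExchange.sampleSizePerPartition.
    Dataset<Row> ranged = df.repartitionByRange(8, col("id"));
    System.out.println(ranged.rdd().getNumPartitions());

    spark.stop();
  }
}
```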
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 9abb7fc1fadb4..f72645a257796 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -3370,7 +3370,7 @@ setMethod("flatten",
#'
#' @rdname column_collection_functions
#' @aliases map_entries map_entries,Column-method
-#' @note map_entries since 2.4.0
+#' @note map_entries since 3.0.0
setMethod("map_entries",
signature(x = "Column"),
function(x) {
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
index 497f18c763048..7252351ebebb2 100644
--- a/R/pkg/R/stats.R
+++ b/R/pkg/R/stats.R
@@ -109,7 +109,7 @@ setMethod("corr",
#'
#' Finding frequent items for columns, possibly with false positives.
#' Using the frequent element count algorithm described in
-#' \url{http://dx.doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou.
+#' \url{https://doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou.
#'
#' @param x A SparkDataFrame.
 #' @param cols A vector of column names to search frequent items in.
@@ -143,7 +143,7 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"),
#' *exact* rank of x is close to (p * N). More precisely,
#' floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
#' This method implements a variation of the Greenwald-Khanna algorithm (with some speed
-#' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670
+#' optimizations). The algorithm was first presented in [[https://doi.org/10.1145/375663.375670
#' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna.
#' Note that NA values will be ignored in numerical columns before calculation. For
#' columns only containing NA values, an empty list is returned.
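The documentation above covers the approximate-quantile guarantee and the frequent-items algorithm. A minimal Java sketch of the corresponding DataFrameStatFunctions calls, assuming a Dataset `df` with a numeric column named "value" (the column name and thresholds are illustrative):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public final class ApproxStatsExample {
  // df is assumed to have a numeric column named "value".
  public static void summarize(Dataset<Row> df) {
    // With relativeError = 0.01, the Greenwald-Khanna guarantee above means the
    // returned value's rank is within 1% * N of the exact median rank.
    double[] approxMedian = df.stat().approxQuantile("value", new double[]{0.5}, 0.01);
    System.out.println("approx. median: " + approxMedian[0]);

    // Frequent items with minimum support of 25% of rows; false positives are possible.
    Dataset<Row> frequent = df.stat().freqItems(new String[]{"value"}, 0.25);
    frequent.show();
  }
}
```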
diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R
index 8c75c19ca7ac3..3efb460846fc2 100644
--- a/R/pkg/inst/profile/general.R
+++ b/R/pkg/inst/profile/general.R
@@ -16,6 +16,10 @@
#
.First <- function() {
+ if (utils::compareVersion(paste0(R.version$major, ".", R.version$minor), "3.4.0") == -1) {
+ warning("Support for R prior to version 3.4 is deprecated since Spark 3.0.0")
+ }
+
packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR")
dirs <- strsplit(packageDir, ",")[[1]]
.libPaths(c(dirs, .libPaths()))
diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R
index 8a8111a8c5419..32eb3671b5941 100644
--- a/R/pkg/inst/profile/shell.R
+++ b/R/pkg/inst/profile/shell.R
@@ -16,6 +16,10 @@
#
.First <- function() {
+ if (utils::compareVersion(paste0(R.version$major, ".", R.version$minor), "3.4.0") == -1) {
+ warning("Support for R prior to version 3.4 is deprecated since Spark 3.0.0")
+ }
+
home <- Sys.getenv("SPARK_HOME")
.libPaths(c(file.path(home, "R", "lib"), .libPaths()))
Sys.setenv(NOAWT = 1)
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 4cd7fc857a2b2..77a29c9ecad86 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1674,7 +1674,7 @@ test_that("column functions", {
# check for unparseable
df <- as.DataFrame(list(list("a" = "")))
- expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]], NA)
+ expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]]$a, NA)
# check if array type in string is correctly supported.
jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]"
diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R
index bfb1a046490ec..6f0d2aefee886 100644
--- a/R/pkg/tests/fulltests/test_streaming.R
+++ b/R/pkg/tests/fulltests/test_streaming.R
@@ -127,6 +127,7 @@ test_that("Specify a schema by using a DDL-formatted string when reading", {
expect_false(awaitTermination(q, 5 * 1000))
callJMethod(q@ssq, "processAllAvailable")
expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3)
+ stopQuery(q)
expect_error(read.stream(path = parquetPath, schema = "name stri"),
"DataType stri is not supported.")
diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd
index 7d924efa9d4bb..f80b45b4f36a8 100644
--- a/R/pkg/vignettes/sparkr-vignettes.Rmd
+++ b/R/pkg/vignettes/sparkr-vignettes.Rmd
@@ -57,6 +57,20 @@ First, let's load and attach the package.
library(SparkR)
```
+```{r, include=FALSE}
+# disable eval if java version not supported
+override_eval <- tryCatch(!is.numeric(SparkR:::checkJavaVersion()),
+ error = function(e) { TRUE },
+ warning = function(e) { TRUE })
+
+if (override_eval) {
+ opts_hooks$set(eval = function(options) {
+ options$eval = FALSE
+ options
+ })
+}
+```
+
`SparkSession` is the entry point into SparkR which connects your R program to a Spark cluster. You can create a `SparkSession` using `sparkR.session` and pass in options such as the application name, any Spark packages depended on, etc.
We use default settings in which it runs in local mode. It auto downloads Spark package in the background if no previous installation is found. For more details about setup, see [Spark Session](#SetupSparkSession).
diff --git a/assembly/README b/assembly/README
index d5dafab477410..1fd6d8858348c 100644
--- a/assembly/README
+++ b/assembly/README
@@ -9,4 +9,4 @@ This module is off by default. To activate it specify the profile in the command
If you need to build an assembly for a different version of Hadoop the
hadoop-version system property needs to be set as in this example:
- -Dhadoop.version=2.7.3
+ -Dhadoop.version=2.7.4
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 5cda79564104c..6f8153e847a47 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -20,12 +20,12 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.11</artifactId>
+    <artifactId>spark-parent_2.12</artifactId>
     <version>3.0.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>

-  <artifactId>spark-assembly_2.11</artifactId>
+  <artifactId>spark-assembly_2.12</artifactId>
   <name>Spark Project Assembly</name>
   <url>http://spark.apache.org/</url>
   <packaging>pom</packaging>
@@ -76,7 +76,7 @@
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-avro_2.11</artifactId>
+      <artifactId>spark-avro_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh
index aa5d847f4be2f..9f735f1148da4 100755
--- a/bin/docker-image-tool.sh
+++ b/bin/docker-image-tool.sh
@@ -29,6 +29,20 @@ if [ -z "${SPARK_HOME}" ]; then
fi
. "${SPARK_HOME}/bin/load-spark-env.sh"
+CTX_DIR="$SPARK_HOME/target/tmp/docker"
+
+function is_dev_build {
+ [ ! -f "$SPARK_HOME/RELEASE" ]
+}
+
+function cleanup_ctx_dir {
+ if is_dev_build; then
+ rm -rf "$CTX_DIR"
+ fi
+}
+
+trap cleanup_ctx_dir EXIT
+
function image_ref {
local image="$1"
local add_repo="${2:-1}"
@@ -41,94 +55,136 @@ function image_ref {
echo "$image"
}
+function docker_push {
+ local image_name="$1"
+ if [ ! -z $(docker images -q "$(image_ref ${image_name})") ]; then
+ docker push "$(image_ref ${image_name})"
+ if [ $? -ne 0 ]; then
+ error "Failed to push $image_name Docker image."
+ fi
+ else
+ echo "$(image_ref ${image_name}) image not found. Skipping push for this image."
+ fi
+}
+
+# Create a smaller build context for docker in dev builds to make the build faster. Docker
+# uploads all of the current directory to the daemon, and it can get pretty big with dev
+# builds that contain test log files and other artifacts.
+#
+# Three build contexts are created, one for each image: base, pyspark, and sparkr. For them
+# to have the desired effect, the docker command needs to be executed inside the appropriate
+# context directory.
+#
+# Note: docker does not support symlinks in the build context.
+function create_dev_build_context {(
+ set -e
+ local BASE_CTX="$CTX_DIR/base"
+ mkdir -p "$BASE_CTX/kubernetes"
+ cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \
+ "$BASE_CTX/kubernetes/dockerfiles"
+
+ cp -r "assembly/target/scala-$SPARK_SCALA_VERSION/jars" "$BASE_CTX/jars"
+ cp -r "resource-managers/kubernetes/integration-tests/tests" \
+ "$BASE_CTX/kubernetes/tests"
+
+ mkdir "$BASE_CTX/examples"
+ cp -r "examples/src" "$BASE_CTX/examples/src"
+ # Copy just needed examples jars instead of everything.
+ mkdir "$BASE_CTX/examples/jars"
+ for i in examples/target/scala-$SPARK_SCALA_VERSION/jars/*; do
+ if [ ! -f "$BASE_CTX/jars/$(basename $i)" ]; then
+ cp $i "$BASE_CTX/examples/jars"
+ fi
+ done
+
+ for other in bin sbin data; do
+ cp -r "$other" "$BASE_CTX/$other"
+ done
+
+ local PYSPARK_CTX="$CTX_DIR/pyspark"
+ mkdir -p "$PYSPARK_CTX/kubernetes"
+ cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \
+ "$PYSPARK_CTX/kubernetes/dockerfiles"
+ mkdir "$PYSPARK_CTX/python"
+ cp -r "python/lib" "$PYSPARK_CTX/python/lib"
+
+ local R_CTX="$CTX_DIR/sparkr"
+ mkdir -p "$R_CTX/kubernetes"
+ cp -r "resource-managers/kubernetes/docker/src/main/dockerfiles" \
+ "$R_CTX/kubernetes/dockerfiles"
+ cp -r "R" "$R_CTX/R"
+)}
+
+function img_ctx_dir {
+ if is_dev_build; then
+ echo "$CTX_DIR/$1"
+ else
+ echo "$SPARK_HOME"
+ fi
+}
+
function build {
local BUILD_ARGS
- local IMG_PATH
- local JARS
-
- if [ ! -f "$SPARK_HOME/RELEASE" ]; then
- # Set image build arguments accordingly if this is a source repo and not a distribution archive.
- #
- # Note that this will copy all of the example jars directory into the image, and that will
- # contain a lot of duplicated jars with the main Spark directory. In a proper distribution,
- # the examples directory is cleaned up before generating the distribution tarball, so this
- # issue does not occur.
- IMG_PATH=resource-managers/kubernetes/docker/src/main/dockerfiles
- JARS=assembly/target/scala-$SPARK_SCALA_VERSION/jars
- BUILD_ARGS=(
- ${BUILD_PARAMS}
- --build-arg
- img_path=$IMG_PATH
- --build-arg
- spark_jars=$JARS
- --build-arg
- example_jars=examples/target/scala-$SPARK_SCALA_VERSION/jars
- --build-arg
- k8s_tests=resource-managers/kubernetes/integration-tests/tests
- )
- else
- # Not passed as arguments to docker, but used to validate the Spark directory.
- IMG_PATH="kubernetes/dockerfiles"
- JARS=jars
- BUILD_ARGS=(${BUILD_PARAMS})
+ local SPARK_ROOT="$SPARK_HOME"
+
+ if is_dev_build; then
+ create_dev_build_context || error "Failed to create docker build context."
+ SPARK_ROOT="$CTX_DIR/base"
fi
# Verify that the Docker image content directory is present
- if [ ! -d "$IMG_PATH" ]; then
+ if [ ! -d "$SPARK_ROOT/kubernetes/dockerfiles" ]; then
error "Cannot find docker image. This script must be run from a runnable distribution of Apache Spark."
fi
# Verify that Spark has actually been built/is a runnable distribution
# i.e. the Spark JARs that the Docker files will place into the image are present
- local TOTAL_JARS=$(ls $JARS/spark-* | wc -l)
+ local TOTAL_JARS=$(ls $SPARK_ROOT/jars/spark-* | wc -l)
TOTAL_JARS=$(( $TOTAL_JARS ))
if [ "${TOTAL_JARS}" -eq 0 ]; then
error "Cannot find Spark JARs. This script assumes that Apache Spark has first been built locally or this is a runnable distribution."
fi
+ local BUILD_ARGS=(${BUILD_PARAMS})
local BINDING_BUILD_ARGS=(
${BUILD_PARAMS}
--build-arg
base_img=$(image_ref spark)
)
- local BASEDOCKERFILE=${BASEDOCKERFILE:-"$IMG_PATH/spark/Dockerfile"}
- local PYDOCKERFILE=${PYDOCKERFILE:-"$IMG_PATH/spark/bindings/python/Dockerfile"}
- local RDOCKERFILE=${RDOCKERFILE:-"$IMG_PATH/spark/bindings/R/Dockerfile"}
+ local BASEDOCKERFILE=${BASEDOCKERFILE:-"kubernetes/dockerfiles/spark/Dockerfile"}
+ local PYDOCKERFILE=${PYDOCKERFILE:-false}
+ local RDOCKERFILE=${RDOCKERFILE:-false}
- docker build $NOCACHEARG "${BUILD_ARGS[@]}" \
+ (cd $(img_ctx_dir base) && docker build $NOCACHEARG "${BUILD_ARGS[@]}" \
-t $(image_ref spark) \
- -f "$BASEDOCKERFILE" .
+ -f "$BASEDOCKERFILE" .)
if [ $? -ne 0 ]; then
error "Failed to build Spark JVM Docker image, please refer to Docker build output for details."
fi
- docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \
- -t $(image_ref spark-py) \
- -f "$PYDOCKERFILE" .
+ if [ "${PYDOCKERFILE}" != "false" ]; then
+ (cd $(img_ctx_dir pyspark) && docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \
+ -t $(image_ref spark-py) \
+ -f "$PYDOCKERFILE" .)
+ if [ $? -ne 0 ]; then
+ error "Failed to build PySpark Docker image, please refer to Docker build output for details."
+ fi
+ fi
+
+ if [ "${RDOCKERFILE}" != "false" ]; then
+ (cd $(img_ctx_dir sparkr) && docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \
+ -t $(image_ref spark-r) \
+ -f "$RDOCKERFILE" .)
if [ $? -ne 0 ]; then
- error "Failed to build PySpark Docker image, please refer to Docker build output for details."
+ error "Failed to build SparkR Docker image, please refer to Docker build output for details."
fi
- docker build $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" \
- -t $(image_ref spark-r) \
- -f "$RDOCKERFILE" .
- if [ $? -ne 0 ]; then
- error "Failed to build SparkR Docker image, please refer to Docker build output for details."
fi
}
function push {
- docker push "$(image_ref spark)"
- if [ $? -ne 0 ]; then
- error "Failed to push Spark JVM Docker image."
- fi
- docker push "$(image_ref spark-py)"
- if [ $? -ne 0 ]; then
- error "Failed to push PySpark Docker image."
- fi
- docker push "$(image_ref spark-r)"
- if [ $? -ne 0 ]; then
- error "Failed to push SparkR Docker image."
- fi
+ docker_push "spark"
+ docker_push "spark-py"
+ docker_push "spark-r"
}
function usage {
@@ -143,8 +199,10 @@ Commands:
Options:
-f file Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark.
- -p file Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark.
- -R file Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark.
+ -p file (Optional) Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark.
+ Skips building PySpark docker image if not specified.
+ -R file (Optional) Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark.
+ Skips building SparkR docker image if not specified.
-r repo Repository address.
-t tag Tag to apply to the built image, or to identify the image to be pushed.
-m Use minikube's Docker daemon.
@@ -164,6 +222,9 @@ Examples:
- Build image in minikube with tag "testing"
$0 -m -t testing build
+ - Build PySpark docker image
+ $0 -r docker.io/myrepo -t v2.3.0 -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile build
+
- Build and push image with tag "v2.3.0" to docker.io/myrepo
$0 -r docker.io/myrepo -t v2.3.0 build
$0 -r docker.io/myrepo -t v2.3.0 push
diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh
index 0b5006dbd63ac..0ada5d8d0fc1d 100644
--- a/bin/load-spark-env.sh
+++ b/bin/load-spark-env.sh
@@ -26,15 +26,17 @@ if [ -z "${SPARK_HOME}" ]; then
source "$(dirname "$0")"/find-spark-home
fi
+SPARK_ENV_SH="spark-env.sh"
if [ -z "$SPARK_ENV_LOADED" ]; then
export SPARK_ENV_LOADED=1
export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}"/conf}"
- if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
+ SPARK_ENV_SH="${SPARK_CONF_DIR}/${SPARK_ENV_SH}"
+ if [[ -f "${SPARK_ENV_SH}" ]]; then
# Promote all variable declarations to environment (exported) variables
set -a
- . "${SPARK_CONF_DIR}/spark-env.sh"
+ . ${SPARK_ENV_SH}
set +a
fi
fi
@@ -42,19 +44,22 @@ fi
# Setting SPARK_SCALA_VERSION if not already set.
if [ -z "$SPARK_SCALA_VERSION" ]; then
+ SCALA_VERSION_1=2.12
+ SCALA_VERSION_2=2.11
- ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11"
- ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.12"
-
- if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then
- echo -e "Presence of build for multiple Scala versions detected." 1>&2
- echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION in spark-env.sh.' 1>&2
+ ASSEMBLY_DIR_1="${SPARK_HOME}/assembly/target/scala-${SCALA_VERSION_1}"
+ ASSEMBLY_DIR_2="${SPARK_HOME}/assembly/target/scala-${SCALA_VERSION_2}"
+ ENV_VARIABLE_DOC="https://spark.apache.org/docs/latest/configuration.html#environment-variables"
+ if [[ -d "$ASSEMBLY_DIR_1" && -d "$ASSEMBLY_DIR_2" ]]; then
+ echo "Presence of build for multiple Scala versions detected ($ASSEMBLY_DIR_1 and $ASSEMBLY_DIR_2)." 1>&2
+ echo "Remove one of them or, export SPARK_SCALA_VERSION=$SCALA_VERSION_1 in ${SPARK_ENV_SH}." 1>&2
+ echo "Visit ${ENV_VARIABLE_DOC} for more details about setting environment variables in spark-env.sh." 1>&2
exit 1
fi
- if [ -d "$ASSEMBLY_DIR2" ]; then
- export SPARK_SCALA_VERSION="2.11"
+ if [[ -d "$ASSEMBLY_DIR_1" ]]; then
+ export SPARK_SCALA_VERSION=${SCALA_VERSION_1}
else
- export SPARK_SCALA_VERSION="2.12"
+ export SPARK_SCALA_VERSION=${SCALA_VERSION_2}
fi
fi
diff --git a/build/mvn b/build/mvn
index 5b2b3c8351114..8931f4f9ffabe 100755
--- a/build/mvn
+++ b/build/mvn
@@ -112,7 +112,8 @@ install_zinc() {
# the build/ folder
install_scala() {
# determine the Scala version used in Spark
- local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'`
+ local scala_binary_version=`grep "scala.binary.version" "${_DIR}/../pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'`
+ local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | grep ${scala_binary_version} | head -n1 | awk -F '[<>]' '{print $3}'`
local scala_bin="${_DIR}/scala-${scala_version}/bin/scala"
local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.lightbend.com}
diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml
index 23a0f49206909..f042a12fda3d2 100644
--- a/common/kvstore/pom.xml
+++ b/common/kvstore/pom.xml
@@ -21,12 +21,12 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.11</artifactId>
+    <artifactId>spark-parent_2.12</artifactId>
     <version>3.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>

-  <artifactId>spark-kvstore_2.11</artifactId>
+  <artifactId>spark-kvstore_2.12</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project Local DB</name>
   <url>http://spark.apache.org/</url>
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java
index 448ba78bd23d0..55b4754a6c4ec 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java
@@ -195,6 +195,7 @@ public synchronized void close() throws IOException {
* when Scala wrappers are used, this makes sure that, hopefully, the JNI resources held by
* the iterator will eventually be released.
*/
+ @SuppressWarnings("deprecation")
@Override
protected void finalize() throws Throwable {
db.closeIterator(this);
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index 41fcbf0589499..56d01fa0e8b3d 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -21,12 +21,12 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.11</artifactId>
+    <artifactId>spark-parent_2.12</artifactId>
     <version>3.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>

-  <artifactId>spark-network-common_2.11</artifactId>
+  <artifactId>spark-network-common_2.12</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project Networking</name>
   <url>http://spark.apache.org/</url>
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index 7b28a9a969486..a7afbfa8621c8 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -33,7 +33,7 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
}
@Override
- public Type type() { return Type.ChunkFetchFailure; }
+ public Message.Type type() { return Type.ChunkFetchFailure; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
index 26d063feb5fe3..fe54fcc50dc86 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -32,7 +32,7 @@ public ChunkFetchRequest(StreamChunkId streamChunkId) {
}
@Override
- public Type type() { return Type.ChunkFetchRequest; }
+ public Message.Type type() { return Type.ChunkFetchRequest; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
index 94c2ac9b20e43..d5c9a9b3202fb 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -39,7 +39,7 @@ public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
}
@Override
- public Type type() { return Type.ChunkFetchSuccess; }
+ public Message.Type type() { return Type.ChunkFetchSuccess; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
index f7ffb1bd49bb6..1632fb9e03687 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
@@ -34,7 +34,7 @@ public OneWayMessage(ManagedBuffer body) {
}
@Override
- public Type type() { return Type.OneWayMessage; }
+ public Message.Type type() { return Type.OneWayMessage; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
index a76624ef5dc96..61061903de23f 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -31,7 +31,7 @@ public RpcFailure(long requestId, String errorString) {
}
@Override
- public Type type() { return Type.RpcFailure; }
+ public Message.Type type() { return Type.RpcFailure; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
index 2b30920f0598d..cc1bb95d2d566 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -38,7 +38,7 @@ public RpcRequest(long requestId, ManagedBuffer message) {
}
@Override
- public Type type() { return Type.RpcRequest; }
+ public Message.Type type() { return Type.RpcRequest; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
index d73014ecd8506..c03291e9c0b23 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -33,7 +33,7 @@ public RpcResponse(long requestId, ManagedBuffer message) {
}
@Override
- public Type type() { return Type.RpcResponse; }
+ public Message.Type type() { return Type.RpcResponse; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
index 258ef81c6783d..68fcfa7748611 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
@@ -33,7 +33,7 @@ public StreamFailure(String streamId, String error) {
}
@Override
- public Type type() { return Type.StreamFailure; }
+ public Message.Type type() { return Type.StreamFailure; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
index dc183c043ed9a..1b135af752bd8 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
@@ -34,7 +34,7 @@ public StreamRequest(String streamId) {
}
@Override
- public Type type() { return Type.StreamRequest; }
+ public Message.Type type() { return Type.StreamRequest; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
index 50b811604b84b..568108c4fe5e8 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
@@ -40,7 +40,7 @@ public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
}
@Override
- public Type type() { return Type.StreamResponse; }
+ public Message.Type type() { return Type.StreamResponse; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java
index fa1d26e76b852..7d21151e01074 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/UploadStream.java
@@ -52,7 +52,7 @@ private UploadStream(long requestId, ManagedBuffer meta, long bodyByteCount) {
}
@Override
- public Type type() { return Type.UploadStream; }
+ public Message.Type type() { return Type.UploadStream; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
index 7331c2b481fb1..1b03300d948e2 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -23,6 +23,7 @@
import org.apache.spark.network.buffer.NettyManagedBuffer;
import org.apache.spark.network.protocol.Encoders;
import org.apache.spark.network.protocol.AbstractMessage;
+import org.apache.spark.network.protocol.Message;
/**
* Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
@@ -46,7 +47,7 @@ class SaslMessage extends AbstractMessage {
}
@Override
- public Type type() { return Type.User; }
+ public Message.Type type() { return Type.User; }
@Override
public int encodedLength() {
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index 1f4d75c7e2ec5..1c0aa4da27ff9 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -371,23 +371,33 @@ private void assertErrorsContain(Set<String> errors, Set<String> contains) {
private void assertErrorAndClosed(RpcResult result, String expectedError) {
assertTrue("unexpected success: " + result.successMessages, result.successMessages.isEmpty());
- // we expect 1 additional error, which contains *either* "closed" or "Connection reset"
Set<String> errors = result.errorMessages;
assertEquals("Expected 2 errors, got " + errors.size() + "errors: " +
errors, 2, errors.size());
+ // We expect 1 additional error due to closed connection and here are possible keywords in the
+ // error message.
+ Set<String> possibleClosedErrors = Sets.newHashSet(
+ "closed",
+ "Connection reset",
+ "java.nio.channels.ClosedChannelException",
+ "java.io.IOException: Broken pipe"
+ );
Set<String> containsAndClosed = Sets.newHashSet(expectedError);
- containsAndClosed.add("closed");
- containsAndClosed.add("Connection reset");
+ containsAndClosed.addAll(possibleClosedErrors);
Pair<Set<String>, Set<String>> r = checkErrorsContain(errors, containsAndClosed);
- Set<String> errorsNotFound = r.getRight();
- assertEquals(1, errorsNotFound.size());
- String err = errorsNotFound.iterator().next();
- assertTrue(err.equals("closed") || err.equals("Connection reset"));
+ assertTrue("Got a non-empty set " + r.getLeft(), r.getLeft().isEmpty());
- assertTrue(r.getLeft().isEmpty());
+ Set<String> errorsNotFound = r.getRight();
+ assertEquals(
+ "The size of " + errorsNotFound + " was not " + (possibleClosedErrors.size() - 1),
+ possibleClosedErrors.size() - 1,
+ errorsNotFound.size());
+ for (String err: errorsNotFound) {
+ assertTrue("Found a wrong error " + err, containsAndClosed.contains(err));
+ }
}
private Pair<Set<String>, Set<String>> checkErrorsContain(
diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
index ff717057bb25d..a6d99813a8501 100644
--- a/common/network-shuffle/pom.xml
+++ b/common/network-shuffle/pom.xml
@@ -21,12 +21,12 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.11</artifactId>
+    <artifactId>spark-parent_2.12</artifactId>
     <version>3.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>

-  <artifactId>spark-network-shuffle_2.11</artifactId>
+  <artifactId>spark-network-shuffle_2.12</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project Shuffle Streaming Service</name>
   <url>http://spark.apache.org/</url>
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
index f309dda8afca6..6bf3da94030d4 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
@@ -101,7 +101,7 @@ void createAndStart(String[] blockIds, BlockFetchingListener listener)
public RetryingBlockFetcher(
TransportConf conf,
- BlockFetchStarter fetchStarter,
+ RetryingBlockFetcher.BlockFetchStarter fetchStarter,
String[] blockIds,
BlockFetchingListener listener) {
this.fetchStarter = fetchStarter;
diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml
index a1cf761d12d8b..55cdc3140aa08 100644
--- a/common/network-yarn/pom.xml
+++ b/common/network-yarn/pom.xml
@@ -21,12 +21,12 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.11</artifactId>
+    <artifactId>spark-parent_2.12</artifactId>
     <version>3.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>

-  <artifactId>spark-network-yarn_2.11</artifactId>
+  <artifactId>spark-network-yarn_2.12</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project YARN Shuffle Service</name>
   <url>http://spark.apache.org/</url>
diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml
index adbbcb1cb3040..3c3c0d2d96a1c 100644
--- a/common/sketch/pom.xml
+++ b/common/sketch/pom.xml
@@ -21,12 +21,12 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.11</artifactId>
+    <artifactId>spark-parent_2.12</artifactId>
     <version>3.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>

-  <artifactId>spark-sketch_2.11</artifactId>
+  <artifactId>spark-sketch_2.12</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project Sketch</name>
   <url>http://spark.apache.org/</url>
diff --git a/common/tags/pom.xml b/common/tags/pom.xml
index f6627beabe84b..883b73a69c9de 100644
--- a/common/tags/pom.xml
+++ b/common/tags/pom.xml
@@ -21,12 +21,12 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.11</artifactId>
+    <artifactId>spark-parent_2.12</artifactId>
     <version>3.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>

-  <artifactId>spark-tags_2.11</artifactId>
+  <artifactId>spark-tags_2.12</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project Tags</name>
   <url>http://spark.apache.org/</url>
diff --git a/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java b/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java
index 0ecef6db0e039..890f2faca28b0 100644
--- a/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java
+++ b/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java
@@ -29,6 +29,7 @@
* of the known issue that Scaladoc displays only either the annotation or the comment, whichever
* comes first.
*/
+@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE})
diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java b/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java
new file mode 100644
index 0000000000000..87e8948f204ff
--- /dev/null
+++ b/common/tags/src/main/java/org/apache/spark/annotation/Evolving.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.annotation;
+
+import java.lang.annotation.*;
+
+/**
+ * APIs that are meant to evolve towards becoming stable APIs, but are not stable APIs yet.
+ * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2).
+ */
+@Documented
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
+ ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE})
+public @interface Evolving {}
diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java b/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java
index ff8120291455f..96875920cd9c3 100644
--- a/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java
+++ b/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java
@@ -30,6 +30,7 @@
* of the known issue that Scaladoc displays only either the annotation or the comment, whichever
* comes first.
*/
+@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE})
diff --git a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java b/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java
deleted file mode 100644
index 323098f69c6e1..0000000000000
--- a/common/tags/src/main/java/org/apache/spark/annotation/InterfaceStability.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.annotation;
-
-import java.lang.annotation.Documented;
-
-/**
- * Annotation to inform users of how much to rely on a particular package,
- * class or method not changing over time.
- */
-public class InterfaceStability {
-
- /**
- * Stable APIs that retain source and binary compatibility within a major release.
- * These interfaces can change from one major release to another major release
- * (e.g. from 1.0 to 2.0).
- */
- @Documented
- public @interface Stable {};
-
- /**
- * APIs that are meant to evolve towards becoming stable APIs, but are not stable APIs yet.
- * Evolving interfaces can change from one feature release to another release (i.e. 2.1 to 2.2).
- */
- @Documented
- public @interface Evolving {};
-
- /**
- * Unstable APIs, with no guarantee on stability.
- * Classes that are unannotated are considered Unstable.
- */
- @Documented
- public @interface Unstable {};
-}
diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Private.java b/common/tags/src/main/java/org/apache/spark/annotation/Private.java
index 9082fcf0c84bc..a460d608ae16b 100644
--- a/common/tags/src/main/java/org/apache/spark/annotation/Private.java
+++ b/common/tags/src/main/java/org/apache/spark/annotation/Private.java
@@ -17,10 +17,7 @@
package org.apache.spark.annotation;
-import java.lang.annotation.ElementType;
-import java.lang.annotation.Retention;
-import java.lang.annotation.RetentionPolicy;
-import java.lang.annotation.Target;
+import java.lang.annotation.*;
/**
* A class that is considered private to the internals of Spark -- there is a high-likelihood
@@ -35,6 +32,7 @@
* of the known issue that Scaladoc displays only either the annotation or the comment, whichever
* comes first.
*/
+@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE})
diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Stable.java b/common/tags/src/main/java/org/apache/spark/annotation/Stable.java
new file mode 100644
index 0000000000000..b198bfbe91e10
--- /dev/null
+++ b/common/tags/src/main/java/org/apache/spark/annotation/Stable.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.annotation;
+
+import java.lang.annotation.*;
+
+/**
+ * Stable APIs that retain source and binary compatibility within a major release.
+ * These interfaces can change from one major release to another major release
+ * (e.g. from 1.0 to 2.0).
+ */
+@Documented
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
+ ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE})
+public @interface Stable {}
diff --git a/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java b/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java
new file mode 100644
index 0000000000000..88ee72125b23f
--- /dev/null
+++ b/common/tags/src/main/java/org/apache/spark/annotation/Unstable.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.annotation;
+
+import java.lang.annotation.*;
+
+/**
+ * Unstable APIs, with no guarantee on stability.
+ * Classes that are unannotated are considered Unstable.
+ */
+@Documented
+@Retention(RetentionPolicy.RUNTIME)
+@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
+ ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE})
+public @interface Unstable {}
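With `InterfaceStability` removed, API classes are expected to carry the new top-level annotations directly. A minimal sketch of what an annotated class looks like after this change (the class name is hypothetical):

```java
import org.apache.spark.annotation.Evolving;

// Before this change a class would have been tagged @InterfaceStability.Evolving;
// after it, the top-level @Evolving annotation is used instead, and because it is
// @Documented it also appears in generated Javadoc.
@Evolving
public class ExampleEvolvingApi {
  public void doSomething() { }
}
```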
diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml
index 62c493a5e1ed8..93a4f67fd23f2 100644
--- a/common/unsafe/pom.xml
+++ b/common/unsafe/pom.xml
@@ -21,12 +21,12 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.11</artifactId>
+    <artifactId>spark-parent_2.12</artifactId>
     <version>3.0.0-SNAPSHOT</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>

-  <artifactId>spark-unsafe_2.11</artifactId>
+  <artifactId>spark-unsafe_2.12</artifactId>
   <packaging>jar</packaging>
   <name>Spark Project Unsafe</name>
   <url>http://spark.apache.org/</url>
@@ -89,6 +89,11 @@
       <artifactId>commons-lang3</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-text</artifactId>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index aca6fca00c48b..4563efcfcf474 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -19,10 +19,10 @@
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
-import sun.misc.Cleaner;
import sun.misc.Unsafe;
public final class Platform {
@@ -67,6 +67,60 @@ public final class Platform {
unaligned = _unaligned;
}
+ // Access fields and constructors once and store them, for performance:
+
+ private static final Constructor<?> DBB_CONSTRUCTOR;
+ private static final Field DBB_CLEANER_FIELD;
+ static {
+ try {
+ Class<?> cls = Class.forName("java.nio.DirectByteBuffer");
+ Constructor<?> constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE);
+ constructor.setAccessible(true);
+ Field cleanerField = cls.getDeclaredField("cleaner");
+ cleanerField.setAccessible(true);
+ DBB_CONSTRUCTOR = constructor;
+ DBB_CLEANER_FIELD = cleanerField;
+ } catch (ClassNotFoundException | NoSuchMethodException | NoSuchFieldException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ private static final Method CLEANER_CREATE_METHOD;
+ static {
+ // The implementation of Cleaner changed from JDK 8 to 9
+ // Split java.version on non-digit chars:
+ int majorVersion = Integer.parseInt(System.getProperty("java.version").split("\\D+")[0]);
+ String cleanerClassName;
+ if (majorVersion < 9) {
+ cleanerClassName = "sun.misc.Cleaner";
+ } else {
+ cleanerClassName = "jdk.internal.ref.Cleaner";
+ }
+ try {
+ Class<?> cleanerClass = Class.forName(cleanerClassName);
+ Method createMethod = cleanerClass.getMethod("create", Object.class, Runnable.class);
+ // Accessing jdk.internal.ref.Cleaner should actually fail by default in JDK 9+,
+ // unfortunately, unless the user has allowed access with something like
+ // --add-opens java.base/java.lang=ALL-UNNAMED If not, we can't really use the Cleaner
+ // hack below. It doesn't break, just means the user might run into the default JVM limit
+ // on off-heap memory and increase it or set the flag above. This tests whether it's
+ // available:
+ try {
+ createMethod.invoke(null, null, null);
+ } catch (IllegalAccessException e) {
+ // Don't throw an exception, but can't log here?
+ createMethod = null;
+ } catch (InvocationTargetException ite) {
+ // shouldn't happen; report it
+ throw new IllegalStateException(ite);
+ }
+ CLEANER_CREATE_METHOD = createMethod;
+ } catch (ClassNotFoundException | NoSuchMethodException e) {
+ throw new IllegalStateException(e);
+ }
+
+ }
+
/**
* @return true when running JVM is having sun's Unsafe package available in it and underlying
* system having unaligned-access capability.
@@ -120,6 +174,11 @@ public static float getFloat(Object object, long offset) {
}
public static void putFloat(Object object, long offset, float value) {
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
+ } else if (value == -0.0f) {
+ value = 0.0f;
+ }
_UNSAFE.putFloat(object, offset, value);
}
@@ -128,6 +187,11 @@ public static double getDouble(Object object, long offset) {
}
public static void putDouble(Object object, long offset, double value) {
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
+ } else if (value == -0.0d) {
+ value = 0.0d;
+ }
_UNSAFE.putDouble(object, offset, value);
}
@@ -159,18 +223,18 @@ public static long reallocateMemory(long address, long oldSize, long newSize) {
* MaxDirectMemorySize limit (the default limit is too low and we do not want to require users
* to increase it).
*/
- @SuppressWarnings("unchecked")
public static ByteBuffer allocateDirectBuffer(int size) {
try {
- Class<?> cls = Class.forName("java.nio.DirectByteBuffer");
- Constructor<?> constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE);
- constructor.setAccessible(true);
- Field cleanerField = cls.getDeclaredField("cleaner");
- cleanerField.setAccessible(true);
long memory = allocateMemory(size);
- ByteBuffer buffer = (ByteBuffer) constructor.newInstance(memory, size);
- Cleaner cleaner = Cleaner.create(buffer, () -> freeMemory(memory));
- cleanerField.set(buffer, cleaner);
+ ByteBuffer buffer = (ByteBuffer) DBB_CONSTRUCTOR.newInstance(memory, size);
+ if (CLEANER_CREATE_METHOD != null) {
+ try {
+ DBB_CLEANER_FIELD.set(buffer,
+ CLEANER_CREATE_METHOD.invoke(null, buffer, (Runnable) () -> freeMemory(memory)));
+ } catch (IllegalAccessException | InvocationTargetException e) {
+ throw new IllegalStateException(e);
+ }
+ }
return buffer;
} catch (Exception e) {
throwException(e);
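The normalization added to `putFloat`/`putDouble` exists because `-0.0` and `0.0` (and different NaN encodings) compare equal as numbers yet have distinct bit patterns, so byte-wise comparison of serialized values would otherwise disagree with numeric equality. A small plain-JDK illustration of that underlying fact (not Spark code):

```java
public class MinusZeroBits {
  public static void main(String[] args) {
    // 0.0 and -0.0 compare equal as doubles...
    System.out.println(0.0d == -0.0d); // true
    // ...but their raw bit patterns differ, which breaks byte-level equality.
    System.out.println(Long.toHexString(Double.doubleToRawLongBits(0.0d)));  // 0
    System.out.println(Long.toHexString(Double.doubleToRawLongBits(-0.0d))); // 8000000000000000
    // Canonicalizing NaN has the same motivation: many bit patterns encode NaN.
    System.out.println(Long.toHexString(Double.doubleToRawLongBits(Double.NaN))); // 7ff8000000000000
  }
}
```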
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java
index be62e40412f83..546e8780a6606 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java
@@ -39,7 +39,9 @@ public static int getSize(Object object, long offset) {
case 8:
return (int)Platform.getLong(object, offset);
default:
+ // checkstyle.off: RegexpSinglelineJava
throw new AssertionError("Illegal UAO_SIZE");
+ // checkstyle.on: RegexpSinglelineJava
}
}
@@ -52,7 +54,9 @@ public static void putSize(Object object, long offset, int value) {
Platform.putLong(object, offset, value);
break;
default:
+ // checkstyle.off: RegexpSinglelineJava
throw new AssertionError("Illegal UAO_SIZE");
+ // checkstyle.on: RegexpSinglelineJava
}
}
}
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index 3ad9ac7b4de9c..2474081dad5c9 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -157,4 +157,22 @@ public void heapMemoryReuse() {
Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
Assert.assertEquals(obj3, onheap4.getBaseObject());
}
+
+ @Test
+ // SPARK-26021
+ public void writeMinusZeroIsReplacedWithZero() {
+ byte[] doubleBytes = new byte[Double.BYTES];
+ byte[] floatBytes = new byte[Float.BYTES];
+ Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
+ Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);
+
+ byte[] doubleBytes2 = new byte[Double.BYTES];
+ byte[] floatBytes2 = new byte[Float.BYTES];
+ Platform.putDouble(doubleBytes2, Platform.BYTE_ARRAY_OFFSET, 0.0d);
+ Platform.putFloat(floatBytes2, Platform.BYTE_ARRAY_OFFSET, 0.0f);
+
+ // Make sure the bytes we write from 0.0 and -0.0 are same.
+ Assert.assertArrayEquals(doubleBytes, doubleBytes2);
+ Assert.assertArrayEquals(floatBytes, floatBytes2);
+ }
}
diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
index 9656951810daf..fdb81a06d41c9 100644
--- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
+++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.unsafe.types
-import org.apache.commons.lang3.StringUtils
+import org.apache.commons.text.similarity.LevenshteinDistance
import org.scalacheck.{Arbitrary, Gen}
import org.scalatest.prop.GeneratorDrivenPropertyChecks
// scalastyle:off
@@ -232,7 +232,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty
test("levenshteinDistance") {
forAll { (one: String, another: String) =>
assert(toUTF8(one).levenshteinDistance(toUTF8(another)) ===
- StringUtils.getLevenshteinDistance(one, another))
+ LevenshteinDistance.getDefaultInstance.apply(one, another))
}
}
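The test now computes the expected distance with Apache Commons Text instead of the deprecated `StringUtils.getLevenshteinDistance`. A minimal sketch of the replacement API, using the standard "kitten"/"sitting" example:

```java
import org.apache.commons.text.similarity.LevenshteinDistance;

public class LevenshteinExample {
  public static void main(String[] args) {
    // getDefaultInstance() imposes no distance threshold.
    LevenshteinDistance distance = LevenshteinDistance.getDefaultInstance();
    // Classic example: turning "kitten" into "sitting" takes 3 edits.
    Integer d = distance.apply("kitten", "sitting");
    System.out.println(d); // 3
  }
}
```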
diff --git a/core/benchmarks/KryoSerializerBenchmark-results.txt b/core/benchmarks/KryoSerializerBenchmark-results.txt
new file mode 100644
index 0000000000000..c3ce336d93241
--- /dev/null
+++ b/core/benchmarks/KryoSerializerBenchmark-results.txt
@@ -0,0 +1,12 @@
+================================================================================================
+Benchmark KryoPool vs "pool of 1"
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_131-b11 on Mac OS X 10.14
+Intel(R) Core(TM) i7-4770HQ CPU @ 2.20GHz
+Benchmark KryoPool vs "pool of 1": Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+KryoPool:true 2682 / 3425 0.0 5364627.9 1.0X
+KryoPool:false 8176 / 9292 0.0 16351252.2 0.3X
+
+
diff --git a/core/pom.xml b/core/pom.xml
index 1e9f428d86860..c374862a5d64e 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -20,12 +20,12 @@
   <modelVersion>4.0.0</modelVersion>
   <parent>
     <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.11</artifactId>
+    <artifactId>spark-parent_2.12</artifactId>
     <version>3.0.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>

-  <artifactId>spark-core_2.11</artifactId>
+  <artifactId>spark-core_2.12</artifactId>
   <properties>
     <sbt.project.name>core</sbt.project.name>
   </properties>
diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
index f6d1288cb263d..92bf0ecc1b5cb 100644
--- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
+++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
@@ -27,7 +27,7 @@
* to read a file to avoid extra copy of data between Java and
* native memory which happens when using {@link java.io.BufferedInputStream}.
* Unfortunately, this is not something already available in JDK,
- * {@link sun.nio.ch.ChannelInputStream} supports reading a file using nio,
+ * {@code sun.nio.ch.ChannelInputStream} supports reading a file using nio,
* but does not support buffering.
*/
public final class NioBufferedFileInputStream extends InputStream {
@@ -130,6 +130,7 @@ public synchronized void close() throws IOException {
StorageUtils.dispose(byteBuffer);
}
+ @SuppressWarnings("deprecation")
@Override
protected void finalize() throws IOException {
close();
diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
index 115e1fbb79a2e..4bfd2d358f36f 100644
--- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -83,10 +83,10 @@ public void spill() throws IOException {
public abstract long spill(long size, MemoryConsumer trigger) throws IOException;
/**
- * Allocates a LongArray of `size`. Note that this method may throw `OutOfMemoryError` if Spark
- * doesn't have enough memory for this allocation, or throw `TooLargePageException` if this
- * `LongArray` is too large to fit in a single page. The caller side should take care of these
- * two exceptions, or make sure the `size` is small enough that won't trigger exceptions.
+ * Allocates a LongArray of `size`. Note that this method may throw `SparkOutOfMemoryError`
+ * if Spark doesn't have enough memory for this allocation, or throw `TooLargePageException`
+ * if this `LongArray` is too large to fit in a single page. The caller side should take care of
+ * these two exceptions, or make sure the `size` is small enough that it won't trigger exceptions.
*
* @throws SparkOutOfMemoryError
* @throws TooLargePageException
@@ -111,7 +111,7 @@ public void freeArray(LongArray array) {
/**
* Allocate a memory block with at least `required` bytes.
*
- * @throws OutOfMemoryError
+ * @throws SparkOutOfMemoryError
*/
protected MemoryBlock allocatePage(long required) {
MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this);
@@ -154,7 +154,9 @@ private void throwOom(final MemoryBlock page, final long required) {
taskMemoryManager.freePage(page, this);
}
taskMemoryManager.showMemoryUsage();
+ // checkstyle.off: RegexpSinglelineJava
throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " +
got);
+ // checkstyle.on: RegexpSinglelineJava
}
}
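The contract spelled out in the updated javadoc can be exercised as below; a minimal sketch with a hypothetical consumer and fallback policy (`spill` is stubbed out, and a real caller would also handle `TooLargePageException`):

```scala
import org.apache.spark.memory.{MemoryConsumer, MemoryMode, SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.unsafe.array.LongArray

// Hypothetical consumer: retries with a smaller request when Spark signals
// SparkOutOfMemoryError (rather than a JVM-level OutOfMemoryError).
class DemoConsumer(tmm: TaskMemoryManager)
  extends MemoryConsumer(tmm, tmm.pageSizeBytes(), MemoryMode.ON_HEAP) {

  override def spill(size: Long, trigger: MemoryConsumer): Long = 0L // nothing to spill

  def allocateWithFallback(size: Long): LongArray =
    try allocateArray(size)
    catch {
      case _: SparkOutOfMemoryError => allocateArray(size / 2)
    }
}
```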
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index d07faf1da1248..28b646ba3c951 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -194,8 +194,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
throw new RuntimeException(e.getMessage());
} catch (IOException e) {
logger.error("error while calling spill() on " + c, e);
+ // checkstyle.off: RegexpSinglelineJava
throw new SparkOutOfMemoryError("error while calling spill() on " + c + " : "
+ e.getMessage());
+ // checkstyle.on: RegexpSinglelineJava
}
}
}
@@ -215,8 +217,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
throw new RuntimeException(e.getMessage());
} catch (IOException e) {
logger.error("error while calling spill() on " + consumer, e);
+ // checkstyle.off: RegexpSinglelineJava
throw new SparkOutOfMemoryError("error while calling spill() on " + consumer + " : "
+ e.getMessage());
+ // checkstyle.on: RegexpSinglelineJava
}
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index b020a6d99247b..fda33cd8293d5 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -37,12 +37,11 @@
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
-import org.apache.spark.TaskContext;
-import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
@@ -79,7 +78,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter {
private final int numPartitions;
private final BlockManager blockManager;
private final Partitioner partitioner;
- private final ShuffleWriteMetrics writeMetrics;
+ private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
private final Serializer serializer;
@@ -103,8 +102,8 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter {
IndexShuffleBlockResolver shuffleBlockResolver,
BypassMergeSortShuffleHandle handle,
int mapId,
- TaskContext taskContext,
- SparkConf conf) {
+ SparkConf conf,
+ ShuffleWriteMetricsReporter writeMetrics) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -114,7 +113,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter {
this.shuffleId = dep.shuffleId();
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
- this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
+ this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
this.shuffleBlockResolver = shuffleBlockResolver;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 1c0d664afb138..6ee9d5f0eec3b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -38,6 +38,7 @@
import org.apache.spark.memory.TooLargePageException;
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.FileSegment;
@@ -75,7 +76,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
- private final ShuffleWriteMetrics writeMetrics;
+ private final ShuffleWriteMetricsReporter writeMetrics;
/**
* Force this sorter to spill when there are this many elements in memory.
@@ -113,7 +114,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
int initialSize,
int numPartitions,
SparkConf conf,
- ShuffleWriteMetrics writeMetrics) {
+ ShuffleWriteMetricsReporter writeMetrics) {
super(memoryManager,
(int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
memoryManager.getTungstenMemoryMode());
@@ -144,7 +145,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
*/
private void writeSortedFile(boolean isLastFile) {
- final ShuffleWriteMetrics writeMetricsToUse;
+ final ShuffleWriteMetricsReporter writeMetricsToUse;
if (isLastFile) {
// We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
@@ -241,9 +242,14 @@ private void writeSortedFile(boolean isLastFile) {
//
// Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
// Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
- // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
- writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten());
- taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten());
+ // SPARK-3577 tracks the spill time separately.
+
+ // This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning
+ // of this method.
+ writeMetrics.incRecordsWritten(
+ ((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten());
+ taskContext.taskMetrics().incDiskBytesSpilled(
+ ((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten());
}
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 4839d04522f10..4b0c74341551e 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -37,7 +37,6 @@
import org.apache.spark.*;
import org.apache.spark.annotation.Private;
-import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
@@ -47,6 +46,7 @@
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
@@ -73,7 +73,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter {
private final TaskMemoryManager memoryManager;
private final SerializerInstance serializer;
private final Partitioner partitioner;
- private final ShuffleWriteMetrics writeMetrics;
+ private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
private final TaskContext taskContext;
@@ -122,7 +122,8 @@ public UnsafeShuffleWriter(
SerializedShuffleHandle handle,
int mapId,
TaskContext taskContext,
- SparkConf sparkConf) throws IOException {
+ SparkConf sparkConf,
+ ShuffleWriteMetricsReporter writeMetrics) throws IOException {
final int numPartitions = handle.dependency().partitioner().numPartitions();
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
@@ -138,7 +139,7 @@ public UnsafeShuffleWriter(
this.shuffleId = dep.shuffleId();
this.serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
- this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
+ this.writeMetrics = writeMetrics;
this.taskContext = taskContext;
this.sparkConf = sparkConf;
this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
index 5d0555a8c28e1..fcba3b73445c9 100644
--- a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
+++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
@@ -21,7 +21,7 @@
import java.io.OutputStream;
import org.apache.spark.annotation.Private;
-import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
/**
* Intercepts write calls and tracks total time spent writing in order to update shuffle write
@@ -30,10 +30,11 @@
@Private
public final class TimeTrackingOutputStream extends OutputStream {
- private final ShuffleWriteMetrics writeMetrics;
+ private final ShuffleWriteMetricsReporter writeMetrics;
private final OutputStream outputStream;
- public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) {
+ public TimeTrackingOutputStream(
+ ShuffleWriteMetricsReporter writeMetrics, OutputStream outputStream) {
this.writeMetrics = writeMetrics;
this.outputStream = outputStream;
}
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 9b6cbab38cbcc..a4e88598f7607 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -31,6 +31,7 @@
import org.apache.spark.SparkEnv;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.SparkOutOfMemoryError;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockManager;
@@ -741,7 +742,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff
if (numKeys >= growthThreshold && longArray.size() < MAX_CAPACITY) {
try {
growAndRehash();
- } catch (OutOfMemoryError oom) {
+ } catch (SparkOutOfMemoryError oom) {
canGrowArray = false;
}
}
@@ -757,7 +758,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff
private boolean acquireNewPage(long required) {
try {
currentPage = allocatePage(required);
- } catch (OutOfMemoryError e) {
+ } catch (SparkOutOfMemoryError e) {
return false;
}
dataPages.add(currentPage);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 75690ae264838..1a9453a8b3e80 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -214,7 +214,9 @@ public boolean hasSpaceForAnotherRecord() {
public void expandPointerArray(LongArray newArray) {
if (newArray.size() < array.size()) {
+ // checkstyle.off: RegexpSinglelineJava
throw new SparkOutOfMemoryError("Not enough memory to grow pointer array");
+ // checkstyle.on: RegexpSinglelineJava
}
Platform.copyMemory(
array.getBaseObject(),
diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html
index 5c91304e49fd7..f2c17aef097a4 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html
+++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html
@@ -16,10 +16,10 @@
-->
diff --git a/core/src/main/resources/org/apache/spark/ui/static/utils.js b/core/src/main/resources/org/apache/spark/ui/static/utils.js
index 4f63f6413d6de..deeafad4eb5f5 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/utils.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/utils.js
@@ -18,7 +18,7 @@
// this function works exactly the same as UIUtils.formatDuration
function formatDuration(milliseconds) {
if (milliseconds < 100) {
- return milliseconds + " ms";
+ return parseInt(milliseconds).toFixed(1) + " ms";
}
var seconds = milliseconds * 1.0 / 1000;
if (seconds < 1) {
@@ -74,3 +74,114 @@ function getTimeZone() {
return new Date().toString().match(/\((.*)\)/)[1];
}
}
+
+function formatLogsCells(execLogs, type) {
+ if (type !== 'display') return Object.keys(execLogs);
+ if (!execLogs) return;
+ var result = '';
+ $.each(execLogs, function (logName, logUrl) {
+ result += '
else Nil)
-
- val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile",
- "Max")
- // The summary table does not use CSS to stripe rows, which doesn't work with hidden
- // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows).
- UIUtils.listingTable(
- quantileHeaders,
- identity[Seq[Node]],
- listings,
- fixedWidth = true,
- id = Some("task-summary-table"),
- stripeRowsWithCss = false)
- }
-
- val executorTable = new ExecutorTable(stageData, parent.store)
-
- val maybeAccumulableTable: Seq[Node] =
- if (hasAccumulators(stageData)) {
-
{rdd.name}
diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
index 21acaa95c5645..f4d6c7a28d2e4 100644
--- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
+++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
@@ -25,11 +25,14 @@ private[spark]
abstract class CompletionIterator[ +A, +I <: Iterator[A]](sub: I) extends Iterator[A] {
private[this] var completed = false
- def next(): A = sub.next()
+ private[this] var iter = sub
+ def next(): A = iter.next()
def hasNext: Boolean = {
- val r = sub.hasNext
+ val r = iter.hasNext
if (!r && !completed) {
completed = true
+ // reassign to release resources of highly resource consuming iterators early
+ iter = Iterator.empty.asInstanceOf[I]
completion()
}
r
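A minimal sketch of the behavior this change tightens, assuming code compiled inside the `org.apache.spark` namespace (the class is `private[spark]`):

```scala
import org.apache.spark.util.CompletionIterator

// The completion callback still fires exactly once when the wrapped iterator
// is drained; after this change the wrapper also drops its reference to
// `underlying`, so a resource-heavy iterator can be reclaimed early.
val underlying = Iterator(1, 2, 3)
val wrapped = CompletionIterator[Int, Iterator[Int]](underlying, println("completed"))
wrapped.sum // consumes 1, 2, 3, then prints "completed"
```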
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 93b5826f8a74b..227c9e734f0af 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -31,7 +31,6 @@ import java.security.SecureRandom
import java.util.{Locale, Properties, Random, UUID}
import java.util.concurrent._
import java.util.concurrent.TimeUnit.NANOSECONDS
-import java.util.concurrent.atomic.AtomicBoolean
import java.util.zip.GZIPInputStream
import scala.annotation.tailrec
@@ -93,53 +92,6 @@ private[spark] object Utils extends Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
@volatile private var localRootDirs: Array[String] = null
- /**
- * The performance overhead of creating and logging strings for wide schemas can be large. To
- * limit the impact, we bound the number of fields to include by default. This can be overridden
- * by setting the 'spark.debug.maxToStringFields' conf in SparkEnv.
- */
- val DEFAULT_MAX_TO_STRING_FIELDS = 25
-
- private[spark] def maxNumToStringFields = {
- if (SparkEnv.get != null) {
- SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS)
- } else {
- DEFAULT_MAX_TO_STRING_FIELDS
- }
- }
-
- /** Whether we have warned about plan string truncation yet. */
- private val truncationWarningPrinted = new AtomicBoolean(false)
-
- /**
- * Format a sequence with semantics similar to calling .mkString(). Any elements beyond
- * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder.
- *
- * @return the trimmed and formatted string.
- */
- def truncatedString[T](
- seq: Seq[T],
- start: String,
- sep: String,
- end: String,
- maxNumFields: Int = maxNumToStringFields): String = {
- if (seq.length > maxNumFields) {
- if (truncationWarningPrinted.compareAndSet(false, true)) {
- logWarning(
- "Truncated the string representation of a plan since it was too large. This " +
- "behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.")
- }
- val numFields = math.max(0, maxNumFields - 1)
- seq.take(numFields).mkString(
- start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end)
- } else {
- seq.mkString(start, sep, end)
- }
- }
-
- /** Shorthand for calling truncatedString() without start or end strings. */
- def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "")
-
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -2328,7 +2280,12 @@ private[spark] object Utils extends Logging {
* configure a new log4j level
*/
def setLogLevel(l: org.apache.log4j.Level) {
- org.apache.log4j.Logger.getRootLogger().setLevel(l)
+ val rootLogger = org.apache.log4j.Logger.getRootLogger()
+ rootLogger.setLevel(l)
+ rootLogger.getAllAppenders().asScala.foreach {
+ case ca: org.apache.log4j.ConsoleAppender => ca.setThreshold(l)
+ case _ => // no-op
+ }
}
/**
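A short sketch of what the change above affects (`Utils` is `private[spark]`; user code reaches the same path via `SparkContext.setLogLevel`):

```scala
import org.apache.log4j.Level
import org.apache.spark.util.Utils

// Previously only the root logger level was lowered, so a ConsoleAppender with
// a higher threshold could still filter messages out; the threshold is now
// lowered too, and e.g. DEBUG output actually reaches the console.
Utils.setLogLevel(Level.DEBUG)
```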
@@ -2430,7 +2387,8 @@ private[spark] object Utils extends Logging {
"org.apache.spark.security.ShellBasedGroupsMappingProvider")
if (groupProviderClassName != "") {
try {
- val groupMappingServiceProvider = classForName(groupProviderClassName).newInstance.
+ val groupMappingServiceProvider = classForName(groupProviderClassName).
+ getConstructor().newInstance().
asInstanceOf[org.apache.spark.security.GroupMappingServiceProvider]
val currentUserGroups = groupMappingServiceProvider.getGroups(username)
return currentUserGroups
@@ -2863,6 +2821,14 @@ private[spark] object Utils extends Logging {
def stringHalfWidth(str: String): Int = {
if (str == null) 0 else str.length + fullWidthRegex.findAllIn(str).size
}
+
+ def sanitizeDirName(str: String): String = {
+ str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase(Locale.ROOT)
+ }
+
+ def isClientMode(conf: SparkConf): Boolean = {
+ "client".equals(conf.get(SparkLauncher.DEPLOY_MODE, "client"))
+ }
}
private[util] object CallerContext extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala
index 828153b868420..c0f8866dd58dc 100644
--- a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala
@@ -23,6 +23,7 @@ package org.apache.spark.util
private[spark] object VersionUtils {
private val majorMinorRegex = """^(\d+)\.(\d+)(\..*)?$""".r
+ private val shortVersionRegex = """^(\d+\.\d+\.\d+)(.*)?$""".r
/**
* Given a Spark version string, return the major version number.
@@ -36,6 +37,19 @@ private[spark] object VersionUtils {
*/
def minorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._2
+ /**
+ * Given a Spark version string, return the short version string.
+ * E.g., for 3.0.0-SNAPSHOT, return '3.0.0'.
+ */
+ def shortVersion(sparkVersion: String): String = {
+ shortVersionRegex.findFirstMatchIn(sparkVersion) match {
+ case Some(m) => m.group(1)
+ case None =>
+ throw new IllegalArgumentException(s"Spark tried to parse '$sparkVersion' as a Spark" +
+ s" version string, but it could not find the major/minor/maintenance version numbers.")
+ }
+ }
+
/**
* Given a Spark version string, return the (major version number, minor version number).
* E.g., for 2.0.1-SNAPSHOT, return (2, 0).
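A quick sketch of the new helper next to the existing one (`VersionUtils` is `private[spark]`):

```scala
import org.apache.spark.util.VersionUtils

VersionUtils.shortVersion("3.0.0-SNAPSHOT")      // "3.0.0"
VersionUtils.majorMinorVersion("2.0.1-SNAPSHOT") // (2, 0)
// A string without a major.minor.maintenance prefix, e.g. "3.0",
// makes shortVersion throw IllegalArgumentException.
```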
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index b159200d79222..46279e79d78db 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -727,9 +727,10 @@ private[spark] class ExternalSorter[K, V, C](
spills.clear()
forceSpillFiles.foreach(s => s.file.delete())
forceSpillFiles.clear()
- if (map != null || buffer != null) {
+ if (map != null || buffer != null || readingIterator != null) {
map = null // So that the memory can be garbage-collected
buffer = null // So that the memory can be garbage-collected
+ readingIterator = null // So that the memory can be garbage-collected
releaseMemory()
}
}
@@ -793,8 +794,8 @@ private[spark] class ExternalSorter[K, V, C](
def nextPartition(): Int = cur._1._1
}
- logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
- s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
+ logInfo(s"Task ${TaskContext.get().taskAttemptId} force spilling in-memory map to disk " +
+ s"and it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
forceSpillFiles += spillFile
val spillReader = new SpillReader(spillFile)
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index 870830fff4c3e..128d6ff8cd746 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -222,7 +222,8 @@ private[spark] class ChunkedByteBufferInputStream(
dispose: Boolean)
extends InputStream {
- private[this] var chunks = chunkedByteBuffer.getChunks().iterator
+ // Filter out empty chunks since `read()` assumes all chunks are non-empty.
+ private[this] var chunks = chunkedByteBuffer.getChunks().filter(_.hasRemaining).iterator
private[this] var currentChunk: ByteBuffer = {
if (chunks.hasNext) {
chunks.next()
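A sketch of the case the filter above guards against (`ChunkedByteBuffer` is `private[spark]`):

```scala
import java.nio.ByteBuffer
import org.apache.spark.util.io.ChunkedByteBuffer

// An empty chunk in the middle used to reach read(), which assumes every
// chunk has remaining bytes; it is now skipped when the iterator is built.
val buf = new ChunkedByteBuffer(Array(
  ByteBuffer.wrap(Array[Byte](1, 2)),
  ByteBuffer.allocate(0), // empty chunk
  ByteBuffer.wrap(Array[Byte](3))))
val in = buf.toInputStream()
// in.read() yields 1, 2, 3 and then -1 (end of stream)
```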
diff --git a/core/src/main/scala/org/apache/spark/util/logging/DriverLogger.scala b/core/src/main/scala/org/apache/spark/util/logging/DriverLogger.scala
new file mode 100644
index 0000000000000..bea18a3df4783
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/logging/DriverLogger.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.logging
+
+import java.io._
+import java.util.concurrent.{ScheduledExecutorService, TimeUnit}
+
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path}
+import org.apache.hadoop.fs.permission.FsPermission
+import org.apache.log4j.{FileAppender => Log4jFileAppender, _}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+private[spark] class DriverLogger(conf: SparkConf) extends Logging {
+
+ private val UPLOAD_CHUNK_SIZE = 1024 * 1024
+ private val UPLOAD_INTERVAL_IN_SECS = 5
+ private val DEFAULT_LAYOUT = "%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n"
+ private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
+
+ private var localLogFile: String = FileUtils.getFile(
+ Utils.getLocalDir(conf),
+ DriverLogger.DRIVER_LOG_DIR,
+ DriverLogger.DRIVER_LOG_FILE).getAbsolutePath()
+ private var writer: Option[DfsAsyncWriter] = None
+
+ addLogAppender()
+
+ private def addLogAppender(): Unit = {
+ val appenders = LogManager.getRootLogger().getAllAppenders()
+ val layout = if (conf.contains(DRIVER_LOG_LAYOUT)) {
+ new PatternLayout(conf.get(DRIVER_LOG_LAYOUT).get)
+ } else if (appenders.hasMoreElements()) {
+ appenders.nextElement().asInstanceOf[Appender].getLayout()
+ } else {
+ new PatternLayout(DEFAULT_LAYOUT)
+ }
+ val fa = new Log4jFileAppender(layout, localLogFile)
+ fa.setName(DriverLogger.APPENDER_NAME)
+ LogManager.getRootLogger().addAppender(fa)
+ logInfo(s"Added a local log appender at: ${localLogFile}")
+ }
+
+ def startSync(hadoopConf: Configuration): Unit = {
+ try {
+ // Setup a writer which moves the local file to hdfs continuously
+ val appId = Utils.sanitizeDirName(conf.getAppId)
+ writer = Some(new DfsAsyncWriter(appId, hadoopConf))
+ } catch {
+ case e: Exception =>
+ logError(s"Could not persist driver logs to dfs", e)
+ }
+ }
+
+ def stop(): Unit = {
+ try {
+ val fa = LogManager.getRootLogger.getAppender(DriverLogger.APPENDER_NAME)
+ LogManager.getRootLogger().removeAppender(DriverLogger.APPENDER_NAME)
+ Utils.tryLogNonFatalError(fa.close())
+ writer.foreach(_.closeWriter())
+ } catch {
+ case e: Exception =>
+ logError(s"Error in persisting driver logs", e)
+ } finally {
+ Utils.tryLogNonFatalError {
+ JavaUtils.deleteRecursively(FileUtils.getFile(localLogFile).getParentFile())
+ }
+ }
+ }
+
+ // Visible for testing
+ private[spark] class DfsAsyncWriter(appId: String, hadoopConf: Configuration) extends Runnable
+ with Logging {
+
+ private var streamClosed = false
+ private var inStream: InputStream = null
+ private var outputStream: FSDataOutputStream = null
+ private val tmpBuffer = new Array[Byte](UPLOAD_CHUNK_SIZE)
+ private var threadpool: ScheduledExecutorService = _
+ init()
+
+ private def init(): Unit = {
+ val rootDir = conf.get(DRIVER_LOG_DFS_DIR).get
+ val fileSystem: FileSystem = new Path(rootDir).getFileSystem(hadoopConf)
+ if (!fileSystem.exists(new Path(rootDir))) {
+ throw new RuntimeException(s"${rootDir} does not exist." +
+ s" Please create this dir in order to persist driver logs")
+ }
+ val dfsLogFile: String = FileUtils.getFile(rootDir, appId
+ + DriverLogger.DRIVER_LOG_FILE_SUFFIX).getAbsolutePath()
+ try {
+ inStream = new BufferedInputStream(new FileInputStream(localLogFile))
+ outputStream = fileSystem.create(new Path(dfsLogFile), true)
+ fileSystem.setPermission(new Path(dfsLogFile), LOG_FILE_PERMISSIONS)
+ } catch {
+ case e: Exception =>
+ JavaUtils.closeQuietly(inStream)
+ JavaUtils.closeQuietly(outputStream)
+ throw e
+ }
+ threadpool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dfsSyncThread")
+ threadpool.scheduleWithFixedDelay(this, UPLOAD_INTERVAL_IN_SECS, UPLOAD_INTERVAL_IN_SECS,
+ TimeUnit.SECONDS)
+ logInfo(s"Started driver log file sync to: ${dfsLogFile}")
+ }
+
+ def run(): Unit = {
+ if (streamClosed) {
+ return
+ }
+ try {
+ var remaining = inStream.available()
+ while (remaining > 0) {
+ val read = inStream.read(tmpBuffer, 0, math.min(remaining, UPLOAD_CHUNK_SIZE))
+ outputStream.write(tmpBuffer, 0, read)
+ remaining -= read
+ }
+ outputStream.hflush()
+ } catch {
+ case e: Exception => logError("Failed writing driver logs to dfs", e)
+ }
+ }
+
+ private def close(): Unit = {
+ if (streamClosed) {
+ return
+ }
+ try {
+ // Write all remaining bytes
+ run()
+ } finally {
+ try {
+ streamClosed = true
+ inStream.close()
+ outputStream.close()
+ } catch {
+ case e: Exception =>
+ logError("Error in closing driver log input/output stream", e)
+ }
+ }
+ }
+
+ def closeWriter(): Unit = {
+ try {
+ threadpool.execute(new Runnable() {
+ override def run(): Unit = DfsAsyncWriter.this.close()
+ })
+ threadpool.shutdown()
+ threadpool.awaitTermination(1, TimeUnit.MINUTES)
+ } catch {
+ case e: Exception =>
+ logError("Error in shutting down threadpool", e)
+ }
+ }
+ }
+
+}
+
+private[spark] object DriverLogger extends Logging {
+ val DRIVER_LOG_DIR = "__driver_logs__"
+ val DRIVER_LOG_FILE = "driver.log"
+ val DRIVER_LOG_FILE_SUFFIX = "_" + DRIVER_LOG_FILE
+ val APPENDER_NAME = "_DriverLogAppender"
+
+ def apply(conf: SparkConf): Option[DriverLogger] = {
+ if (conf.get(DRIVER_LOG_PERSISTTODFS) && Utils.isClientMode(conf)) {
+ if (conf.contains(DRIVER_LOG_DFS_DIR)) {
+ try {
+ Some(new DriverLogger(conf))
+ } catch {
+ case e: Exception =>
+ logError("Could not add driver logger", e)
+ None
+ }
+ } else {
+ logWarning(s"Driver logs are not persisted because" +
+ s" ${DRIVER_LOG_DFS_DIR.key} is not configured")
+ None
+ }
+ } else {
+ None
+ }
+ }
+}
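`DriverLogger.apply` above only activates for client-mode applications with persistence enabled and a DFS directory configured; a minimal sketch of that configuration (the HDFS path is hypothetical):

```scala
import org.apache.spark.SparkConf

// The three conditions checked by DriverLogger.apply: client deploy mode
// (the default), persistence enabled, and a target directory on DFS.
val conf = new SparkConf()
  .set("spark.submit.deployMode", "client")
  .set("spark.driver.log.persistToDfs.enabled", "true")
  .set("spark.driver.log.dfsDir", "hdfs:///user/spark/driverLogs") // hypothetical path
```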
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index ea99a7e5b4847..70554f1d03067 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -49,7 +49,7 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
/** return a copy of the RandomSampler object */
override def clone: RandomSampler[T, U] =
- throw new NotImplementedError("clone() is not implemented.")
+ throw new UnsupportedOperationException("clone() is not implemented.")
}
private[spark]
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index a07d0e84ea854..30ad3f5575545 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -162,7 +162,8 @@ private UnsafeShuffleWriter
+
+
spark.driver.log.dfsDir
+
(none)
+
+ Base directory in which Spark driver logs are synced, if spark.driver.log.persistToDfs.enabled
+ is true. Within this base directory, each application logs the driver logs to an application specific file.
+ Users may want to set this to a unified location like an HDFS directory so driver log files can be persisted
+ for later usage. This directory should allow any Spark user to read/write files and the Spark History Server
+ user to delete files. Additionally, older logs from this directory are cleaned by the
+ Spark History Server if
+ spark.history.fs.driverlog.cleaner.enabled is true and they are older than the max age configured
+ by spark.history.fs.driverlog.cleaner.maxAge.
+
+
+
+
spark.driver.log.persistToDfs.enabled
+
false
+
+ If true, Spark applications running in client mode will write driver logs to persistent storage, configured
+ in spark.driver.log.dfsDir. If spark.driver.log.dfsDir is not configured, driver logs
+ will not be persisted. Additionally, enable the cleaner by setting spark.history.fs.driverlog.cleaner.enabled
+ to true in the Spark History Server.
+
+
+
+
spark.driver.log.layout
+
%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+ The layout for the driver logs that are synced to spark.driver.log.dfsDir. If this is not configured,
+ it uses the layout for the first appender defined in log4j.properties. If that is also not configured, driver logs
+ use the default layout.
+
+
Apart from these, the following properties are also available, and may be useful in some situations:
@@ -940,6 +973,14 @@ Apart from these, the following properties are also available, and may be useful
spark.com.test.filter1.param.name2=bar
+
+
spark.ui.requestHeaderSize
+
8k
+
+ The maximum allowed size for an HTTP request header, in bytes unless otherwise specified.
+ This setting also applies to the Spark History Server.
+
+
### Compression and Serialization
diff --git a/docs/index.md b/docs/index.md
index b9996cc8645d9..b842fca6245ae 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -31,7 +31,8 @@ Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy
locally on one machine --- all you need is to have `java` installed on your system `PATH`,
or the `JAVA_HOME` environment variable pointing to a Java installation.
-Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}}
+Spark runs on Java 8+, Python 2.7+/3.4+ and R 3.1+. Support for R prior to version 3.4 is deprecated as of Spark 3.0.0.
+For the Scala API, Spark {{site.SPARK_VERSION}}
uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version
({{site.SCALA_BINARY_VERSION}}.x).
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index b3d109039da4d..42912a2e2bc31 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -941,9 +941,9 @@ Essentially isotonic regression is a
best fitting the original data points.
We implement a
-[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111)
+[pool adjacent violators algorithm](https://doi.org/10.1198/TECH.2010.10111)
which uses an approach to
-[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10).
+[parallelizing isotonic regression](https://doi.org/10.1007/978-3-642-99789-1_10).
The training input is a DataFrame which contains three columns
label, features and weight. Additionally, IsotonicRegression algorithm has one
optional parameter called $isotonic$ defaulting to true.
diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md
index 8b0f287dc39ad..58646642bfbcc 100644
--- a/docs/ml-collaborative-filtering.md
+++ b/docs/ml-collaborative-filtering.md
@@ -41,7 +41,7 @@ for example, users giving ratings to movies.
It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views,
clicks, purchases, likes, shares etc.). The approach used in `spark.ml` to deal with such data is taken
-from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22).
+from [Collaborative Filtering for Implicit Feedback Datasets](https://doi.org/10.1109/ICDM.2008.22).
Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data
as numbers representing the *strength* in observations of user actions (such as the number of clicks,
or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of
@@ -55,7 +55,7 @@ We scale the regularization parameter `regParam` in solving each least squares p
the number of ratings the user generated in updating user factors,
or the number of ratings the product received in updating product factors.
This approach is named "ALS-WR" and discussed in the paper
-"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)".
+"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](https://doi.org/10.1007/978-3-540-68880-8_32)".
It makes `regParam` less dependent on the scale of the dataset, so we can apply the
best parameter learned from a sampled subset to the full dataset and expect similar performance.
diff --git a/docs/ml-frequent-pattern-mining.md b/docs/ml-frequent-pattern-mining.md
index c2043d495c149..f613664271ec6 100644
--- a/docs/ml-frequent-pattern-mining.md
+++ b/docs/ml-frequent-pattern-mining.md
@@ -18,7 +18,7 @@ for more information.
## FP-Growth
The FP-growth algorithm is described in the paper
-[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372),
+[Han et al., Mining frequent patterns without candidate generation](https://doi.org/10.1145/335191.335372),
where "FP" stands for frequent pattern.
Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items.
Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose,
@@ -26,7 +26,7 @@ the second step of FP-growth uses a suffix tree (FP-tree) structure to encode tr
explicitly, which are usually expensive to generate.
After the second step, the frequent itemsets can be extracted from the FP-tree.
In `spark.mllib`, we implemented a parallel version of FP-growth called PFP,
-as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027).
+as described in [Li et al., PFP: Parallel FP-growth for query recommendation](https://doi.org/10.1145/1454008.1454027).
PFP distributes the work of growing FP-trees based on the suffixes of transactions,
and hence is more scalable than a single-machine implementation.
We refer users to the papers for more details.
@@ -90,7 +90,7 @@ Refer to the [R API docs](api/R/spark.fpGrowth.html) for more details.
PrefixSpan is a sequential pattern mining algorithm described in
[Pei et al., Mining Sequential Patterns by Pattern-Growth: The
-PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer
+PrefixSpan Approach](https://doi.org/10.1109%2FTKDE.2004.77). We refer
the reader to the referenced paper for formalizing the sequential
pattern mining problem.
@@ -137,4 +137,4 @@ Refer to the [R API docs](api/R/spark.prefixSpan.html) for more details.
{% include_example r/ml/prefixSpan.R %}
-
\ No newline at end of file
+
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index b2300028e151b..aeebb26bb45f3 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -37,7 +37,7 @@ for example, users giving ratings to movies.
It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views,
clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken
-from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22).
+from [Collaborative Filtering for Implicit Feedback Datasets](https://doi.org/10.1109/ICDM.2008.22).
Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data
as numbers representing the *strength* in observations of user actions (such as the number of clicks,
or the cumulative duration someone spent viewing a movie). Those numbers are then related to the level of
@@ -51,7 +51,7 @@ Since v1.1, we scale the regularization parameter `lambda` in solving each least
the number of ratings the user generated in updating user factors,
or the number of ratings the product received in updating product factors.
This approach is named "ALS-WR" and discussed in the paper
-"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](http://dx.doi.org/10.1007/978-3-540-68880-8_32)".
+"[Large-Scale Parallel Collaborative Filtering for the Netflix Prize](https://doi.org/10.1007/978-3-540-68880-8_32)".
It makes `lambda` less dependent on the scale of the dataset, so we can apply the
best parameter learned from a sampled subset to the full dataset and expect similar performance.
diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md
index 0d3192c6b1d9c..8e4505756b275 100644
--- a/docs/mllib-frequent-pattern-mining.md
+++ b/docs/mllib-frequent-pattern-mining.md
@@ -15,7 +15,7 @@ a popular algorithm to mining frequent itemsets.
## FP-growth
The FP-growth algorithm is described in the paper
-[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372),
+[Han et al., Mining frequent patterns without candidate generation](https://doi.org/10.1145/335191.335372),
where "FP" stands for frequent pattern.
Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items.
Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose,
@@ -23,7 +23,7 @@ the second step of FP-growth uses a suffix tree (FP-tree) structure to encode tr
explicitly, which are usually expensive to generate.
After the second step, the frequent itemsets can be extracted from the FP-tree.
In `spark.mllib`, we implemented a parallel version of FP-growth called PFP,
-as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027).
+as described in [Li et al., PFP: Parallel FP-growth for query recommendation](https://doi.org/10.1145/1454008.1454027).
PFP distributes the work of growing FP-trees based on the suffixes of transactions,
and hence more scalable than a single-machine implementation.
We refer users to the papers for more details.
@@ -122,7 +122,7 @@ Refer to the [`AssociationRules` Java docs](api/java/org/apache/spark/mllib/fpm/
PrefixSpan is a sequential pattern mining algorithm described in
[Pei et al., Mining Sequential Patterns by Pattern-Growth: The
-PrefixSpan Approach](http://dx.doi.org/10.1109%2FTKDE.2004.77). We refer
+PrefixSpan Approach](https://doi.org/10.1109%2FTKDE.2004.77). We refer
the reader to the referenced paper for formalizing the sequential
pattern mining problem.
diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md
index 99cab98c690c6..9964fce3273be 100644
--- a/docs/mllib-isotonic-regression.md
+++ b/docs/mllib-isotonic-regression.md
@@ -24,9 +24,9 @@ Essentially isotonic regression is a
best fitting the original data points.
`spark.mllib` supports a
-[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111)
+[pool adjacent violators algorithm](https://doi.org/10.1198/TECH.2010.10111)
which uses an approach to
-[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10).
+[parallelizing isotonic regression](https://doi.org/10.1007/978-3-642-99789-1_10).
The training input is an RDD of tuples of three double values that represent
label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one
optional parameter called $isotonic$ defaulting to true.
diff --git a/docs/monitoring.md b/docs/monitoring.md
index 69bf3082f0f27..6bb620a2e5f69 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -202,6 +202,28 @@ Security options for the Spark History Server are covered more detail in the
applications that fail to rename their event logs listed as in-progress.
+
+
spark.history.fs.driverlog.cleaner.enabled
+
spark.history.fs.cleaner.enabled
+
+ Specifies whether the History Server should periodically clean up driver logs from storage.
+
+
+
+
spark.history.fs.driverlog.cleaner.interval
+
spark.history.fs.cleaner.interval
+
+ How often the filesystem driver log cleaner checks for files to delete.
+ Files are only deleted if they are older than spark.history.fs.driverlog.cleaner.maxAge
+
+
+
+
spark.history.fs.driverlog.cleaner.maxAge
+
spark.history.fs.cleaner.maxAge
+
+ Driver log files older than this will be deleted when the driver log cleaner runs.
+
+
spark.history.fs.numReplayThreads
25% of available cores
diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md
index 29a4d86fecb44..ec9514ea4134a 100644
--- a/docs/running-on-kubernetes.md
+++ b/docs/running-on-kubernetes.md
@@ -15,7 +15,19 @@ container images and entrypoints.**
# Security
Security in Spark is OFF by default. This could mean you are vulnerable to attack by default.
-Please see [Spark Security](security.html) and the specific security sections in this doc before running Spark.
+Please see [Spark Security](security.html) and the specific advice below before running Spark.
+
+## User Identity
+
+Images built from the project-provided Dockerfiles do not contain any [`USER`](https://docs.docker.com/engine/reference/builder/#user) directives. This means that the resulting images will be running the Spark processes as `root` inside the container. On unsecured clusters this may provide an attack vector for privilege escalation and container breakout. Therefore, security-conscious deployments should consider providing custom images with `USER` directives specifying an unprivileged UID and GID.
+
+Alternatively, the [Pod Template](#pod-template) feature can be used to add a [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context/#volumes-and-file-systems) with a `runAsUser` to the pods that Spark submits. Please bear in mind that this requires cooperation from your users and as such may not be a suitable solution for shared environments. Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/#users-and-groups) if they wish to limit the users that pods may run as.
+
+## Volume Mounts
+
+As described later in this document under [Using Kubernetes Volumes](#using-kubernetes-volumes), Spark on K8S provides configuration options that allow for mounting certain volume types into the driver and executor pods. In particular, it allows for [`hostPath`](https://kubernetes.io/docs/concepts/storage/volumes/#hostpath) volumes which, as described in the Kubernetes documentation, have known security vulnerabilities.
+
+Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/) to limit the ability to mount `hostPath` volumes appropriately for their environments.
# Prerequisites
@@ -76,6 +88,18 @@ $ ./bin/docker-image-tool.sh -r -t my-tag build
$ ./bin/docker-image-tool.sh -r -t my-tag push
```
+By default `bin/docker-image-tool.sh` builds a Docker image for running JVM jobs. You need to opt in to build the
+additional language-binding Docker images.
+
+Example usage:
+```bash
+# To build additional PySpark docker image
+$ ./bin/docker-image-tool.sh -r -t my-tag -p ./kubernetes/dockerfiles/spark/bindings/python/Dockerfile build
+
+# To build additional SparkR docker image
+$ ./bin/docker-image-tool.sh -r -t my-tag -R ./kubernetes/dockerfiles/spark/bindings/R/Dockerfile build
+```
+
## Cluster Mode
To launch Spark Pi in cluster mode,
@@ -142,7 +166,7 @@ hostname via `spark.driver.host` and your spark driver's port to `spark.driver.p
### Client Mode Executor Pod Garbage Collection
-If you run your Spark driver in a pod, it is highly recommended to set `spark.driver.pod.name` to the name of that pod.
+If you run your Spark driver in a pod, it is highly recommended to set `spark.kubernetes.driver.pod.name` to the name of that pod.
When this property is set, the Spark scheduler will deploy the executor pods with an
[OwnerReference](https://kubernetes.io/docs/concepts/workloads/controllers/garbage-collection/), which in turn will
ensure that once the driver pod is deleted from the cluster, all of the application's executor pods will also be deleted.
@@ -151,7 +175,7 @@ an OwnerReference pointing to that pod will be added to each executor pod's Owne
setting the OwnerReference to a pod that is not actually that driver pod, or else the executors may be terminated
prematurely when the wrong pod is deleted.
-If your application is not running inside a pod, or if `spark.driver.pod.name` is not set when your application is
+If your application is not running inside a pod, or if `spark.kubernetes.driver.pod.name` is not set when your application is
actually running in a pod, keep in mind that the executor pods may not be properly deleted from the cluster when the
application exits. The Spark scheduler attempts to delete these pods, but if the network request to the API server fails
for any reason, these pods will remain in the cluster. The executor processes should exit when they cannot reach the
@@ -214,11 +238,14 @@ Starting with Spark 2.4.0, users can mount the following types of Kubernetes [vo
* [emptyDir](https://kubernetes.io/docs/concepts/storage/volumes/#emptydir): an initially empty volume created when a pod is assigned to a node.
* [persistentVolumeClaim](https://kubernetes.io/docs/concepts/storage/volumes/#persistentvolumeclaim): used to mount a `PersistentVolume` into a pod.
+**NB:** Please see the [Security](#security) section of this document for security issues related to volume mounts.
+
To mount a volume of any of the types above into the driver pod, use the following configuration property:
```
--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.path=
--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.readOnly=
+--conf spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].mount.subPath=
```
Specifically, `VolumeType` can be one of the following values: `hostPath`, `emptyDir`, and `persistentVolumeClaim`. `VolumeName` is the name you want to use for the volume under the `volumes` field in the pod specification.
@@ -780,6 +807,14 @@ specific to Spark on Kubernetes.
spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path=/checkpoint.
+ Specifies a subpath to be mounted from the volume into the driver pod.
+ spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.subPath=checkpoint.
+
+ Specifies a subpath to be mounted from the volume into the executor pod.
+ spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.mount.subPath=checkpoint.
+
diff --git a/docs/security.md b/docs/security.md
index 2f7fa9c6179f4..02d581c6dad91 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -821,3 +821,14 @@ should correspond to the super user who is running the Spark History Server.
This will allow all users to write to the directory but will prevent unprivileged users from
reading, removing or renaming a file unless they own it. The event log files will be created by
Spark with permissions such that only the user and group have read and write access.
+
+# Persisting driver logs in client mode
+
+If your applications persist driver logs in client mode by enabling `spark.driver.log.persistToDfs.enabled`,
+the directory where the driver logs go (`spark.driver.log.dfsDir`) should be manually created with proper
+permissions. To secure the log files, the directory permissions should be set to `drwxrwxrwxt`. The owner
+and group of the directory should correspond to the super user who is running the Spark History Server.
+
+This will allow all users to write to the directory but will prevent unprivileged users from
+reading, removing or renaming a file unless they own it. The driver log files will be created by
+Spark with permissions such that only the user and group have read and write access.
diff --git a/docs/sparkr.md b/docs/sparkr.md
index f84ec504b595a..5972435a0e409 100644
--- a/docs/sparkr.md
+++ b/docs/sparkr.md
@@ -133,7 +133,7 @@ specifying `--packages` with `spark-submit` or `sparkR` commands, or if initiali
diff --git a/docs/sql-data-sources-hive-tables.md b/docs/sql-data-sources-hive-tables.md
index 687e6f8e0a7cc..28e1a39626666 100644
--- a/docs/sql-data-sources-hive-tables.md
+++ b/docs/sql-data-sources-hive-tables.md
@@ -115,7 +115,7 @@ The following options can be used to configure the version of Hive that is used
1.2.1
Version of the Hive metastore. Available
- options are 0.12.0 through 2.3.3.
+ options are 0.12.0 through 2.3.4.
diff --git a/docs/sql-migration-guide-hive-compatibility.md b/docs/sql-migration-guide-hive-compatibility.md
index 94849418030ef..dd7b06225714f 100644
--- a/docs/sql-migration-guide-hive-compatibility.md
+++ b/docs/sql-migration-guide-hive-compatibility.md
@@ -10,7 +10,7 @@ displayTitle: Compatibility with Apache Hive
Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs.
Currently, Hive SerDes and UDFs are based on Hive 1.2.1,
and Spark SQL can be connected to different versions of Hive Metastore
-(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)).
+(from 0.12.0 to 2.3.4. Also see [Interacting with Different Versions of Hive Metastore](sql-data-sources-hive-tables.html#interacting-with-different-versions-of-hive-metastore)).
#### Deploying in Existing Hive Warehouses
diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md
index 80f113de62ccc..25cd541190919 100644
--- a/docs/sql-migration-guide-upgrade.md
+++ b/docs/sql-migration-guide-upgrade.md
@@ -17,8 +17,16 @@ displayTitle: Spark SQL Upgrading Guide
- Since Spark 3.0, the `from_json` functions supports two modes - `PERMISSIVE` and `FAILFAST`. The modes can be set via the `mode` option. The default mode became `PERMISSIVE`. In previous versions, behavior of `from_json` did not conform to either `PERMISSIVE` nor `FAILFAST`, especially in processing of malformed JSON records. For example, the JSON string `{"a" 1}` with the schema `a INT` is converted to `null` by previous versions but Spark 3.0 converts it to `Row(null)`.
+ - In Spark version 2.4 and earlier, the `from_json` function produces `null`s and the JSON datasource skips such input, independently of the mode, if there is no valid root JSON token in the input (` ` for example). Since Spark 3.0, such input is treated as a bad record and handled according to the specified mode. For example, in the `PERMISSIVE` mode the ` ` input is converted to `Row(null, null)` if the specified schema is `key STRING, value INT`.
+
- The `ADD JAR` command previously returned a result set with the single value 0. It now returns an empty result set.
+ - In Spark version 2.4 and earlier, users can create map values with a map type key via built-in functions like `CreateMap`, `MapFromArrays`, etc. Since Spark 3.0, creating map values with a map type key is no longer allowed with these built-in functions. Users can still read map values with a map type key from data sources or Java/Scala collections, though they are not very useful.
+
+ - In Spark version 2.4 and earlier, `Dataset.groupByKey` results in a grouped dataset with the key attribute wrongly named "value" if the key is a non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, the grouping attribute is named "key" (see the sketch after this list). The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`.
+
+ - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more.
+
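A sketch of the `Dataset.groupByKey` schema change called out above, as it would look in spark-shell (where `spark` and its implicits are in scope; the sample data is arbitrary):

```scala
import spark.implicits._

val counts = Seq(1, 1, 2).toDS().groupByKey(identity).count()
counts.printSchema()
// Spark 2.4 names the grouping column "value"; Spark 3.0 names it "key":
//   |-- value: integer   vs.   |-- key: integer
//   |-- count(1): long         |-- count(1): long
```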
## Upgrading From Spark SQL 2.3 to 2.4
- In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below.
@@ -119,7 +127,7 @@ displayTitle: Spark SQL Upgrading Guide
- Since Spark 2.4, Metadata files (e.g. Parquet summary files) and temporary files are not counted as data files when calculating table size during Statistics computation.
- - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string.
+ - Since Spark 2.4, empty strings are saved as quoted empty strings `""`. In version 2.3 and earlier, empty strings are equal to `null` values and do not reflect to any characters in saved CSV files. For example, the row of `"a", null, "", 1` was written as `a,,,1`. Since Spark 2.4, the same row is saved as `a,,"",1`. To restore the previous behavior, set the CSV option `emptyValue` to empty (not quoted) string.
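A sketch of the `emptyValue` option (output paths are placeholders; `spark` is an active `SparkSession`):

```scala
import spark.implicits._

val df = Seq(("a", null.asInstanceOf[String], "", 1)).toDF("c1", "c2", "c3", "c4")

// Spark 2.4 default: the row is written as  a,,"",1
df.write.csv("/tmp/csv-default")

// Restore the pre-2.4 behaviour (empty string written without quotes):  a,,,1
df.write.option("emptyValue", "").csv("/tmp/csv-legacy")
```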
- Since Spark 2.4, The LOAD DATA command supports wildcard `?` and `*`, which match any one character, and zero or more characters, respectively. Example: `LOAD DATA INPATH '/tmp/folder*/'` or `LOAD DATA INPATH '/tmp/part-?'`. Special Characters like `space` also now work in paths. Example: `LOAD DATA INPATH '/tmp/folder name/'`.
@@ -305,7 +313,7 @@ displayTitle: Spark SQL Upgrading Guide
## Upgrading From Spark SQL 2.1 to 2.2
- Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
-
+
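A sketch of opting out of the inference when mixed-case column names are not a concern:

```scala
import org.apache.spark.sql.SparkSession

// Skip the potentially expensive first-access schema inference on Hive tables.
val spark = SparkSession.builder()
  .config("spark.sql.hive.caseSensitiveInferenceMode", "NEVER_INFER")
  .enableHiveSupport()
  .getOrCreate()
```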
- Since Spark 2.2.1 and 2.3.0, the schema is always inferred at runtime when the data source tables have the columns that exist in both partition schema and data schema. The inferred schema does not have the partitioned columns. When reading the table, Spark respects the partition values of these overlapping columns instead of the values stored in the data source files. In 2.2.0 and 2.1.x release, the inferred schema is partitioned but the data of the table is invisible to users (i.e., the result set is empty).
- Since Spark 2.2, view definitions are stored in a different way from prior versions. This may cause Spark unable to read views created by prior versions. In such cases, you need to recreate the views using `ALTER VIEW AS` or `CREATE OR REPLACE VIEW AS` with newer Spark versions.
diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md
index 71fd5b10cc407..a549ce2a6a05f 100644
--- a/docs/structured-streaming-kafka-integration.md
+++ b/docs/structured-streaming-kafka-integration.md
@@ -123,7 +123,7 @@ df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
-### Creating a Kafka Source for Batch Queries
+### Creating a Kafka Source for Batch Queries
If you have a use case that is better suited to batch processing,
you can create a Dataset/DataFrame for a defined range of offsets.
@@ -374,17 +374,24 @@ The following configurations are optional:
streaming and batch
Rate limit on maximum number of offsets processed per trigger interval. The specified total number of offsets will be proportionally split across topicPartitions of different volume.
+<tr>
+  <td>groupIdPrefix</td>
+  <td>string</td>
+  <td>spark-kafka-source</td>
+  <td>streaming and batch</td>
+  <td>Prefix of consumer group identifiers (`group.id`) that are generated by structured streaming queries</td>
+</tr>
## Writing Data to Kafka
-Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. Take note that
+Here, we describe the support for writing Streaming Queries and Batch Queries to Apache Kafka. Take note that
Apache Kafka only supports at least once write semantics. Consequently, when writing---either Streaming Queries
or Batch Queries---to Kafka, some records may be duplicated; this can happen, for example, if Kafka needs
to retry a message that was not acknowledged by a Broker, even though that Broker received and wrote the message record.
-Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However,
+Structured Streaming cannot prevent such duplicates from occurring due to these Kafka write semantics. However,
if writing the query is successful, then you can assume that the query output was written at least once. A possible
-solution to remove duplicates when reading the written data could be to introduce a primary (unique) key
+solution to remove duplicates when reading the written data could be to introduce a primary (unique) key
that can be used to perform de-duplication when reading.
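A sketch of that read-side de-duplication, assuming each record was written with a unique key and that `spark` is an active `SparkSession` (servers and topic are placeholders):

```scala
val deduped = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("subscribe", "events")
  .load()
  .selectExpr("CAST(key AS STRING) AS eventId", "CAST(value AS STRING) AS payload")
  .dropDuplicates("eventId")
```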
The Dataframe being written to Kafka should have the following columns in schema:
@@ -405,8 +412,8 @@ The Dataframe being written to Kafka should have the following columns in schema
\* The topic column is required if the "topic" configuration option is not specified.
-The value column is the only required option. If a key column is not specified then
-a ```null``` valued key column will be automatically added (see Kafka semantics on
+The value column is the only required option. If a key column is not specified then
+a ```null``` valued key column will be automatically added (see Kafka semantics on
how ```null``` valued key values are handled). If a topic column exists then its value
is used as the topic when writing the given row to Kafka, unless the "topic" configuration
option is set i.e., the "topic" configuration option overrides the topic column.
@@ -568,7 +575,7 @@ df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \
.format("kafka") \
.option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
.save()
-
+
{% endhighlight %}
@@ -576,23 +583,25 @@ df.selectExpr("topic", "CAST(key AS STRING)", "CAST(value AS STRING)") \
## Kafka Specific Configurations
-Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g,
-`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see
+Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g,
+`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafka parameters, see
[Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs) for
parameters related to reading data, and [Kafka producer config docs](http://kafka.apache.org/documentation/#producerconfigs)
for parameters related to writing data.
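For example, a sketch that forwards a consumer setting with the `kafka.` prefix (servers, topic, and the chosen consumer setting are placeholders):

```scala
val stream = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092,host2:9092")
  .option("kafka.fetch.max.bytes", "10485760")  // passed through to the Kafka consumer
  .option("subscribe", "topic1")
  .load()
```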
Note that the following Kafka params cannot be set and the Kafka source or sink will throw an exception:
-- **group.id**: Kafka source will create a unique group id for each query automatically.
+- **group.id**: Kafka source will create a unique group id for each query automatically. The user can
+set the prefix of the automatically generated group.id values via the optional source option `groupIdPrefix`;
+the default is "spark-kafka-source". See the sketch after this list.
- **auto.offset.reset**: Set the source option `startingOffsets` to specify
- where to start instead. Structured Streaming manages which offsets are consumed internally, rather
- than rely on the kafka Consumer to do it. This will ensure that no data is missed when new
+ where to start instead. Structured Streaming manages which offsets are consumed internally, rather
+ than rely on the kafka Consumer to do it. This will ensure that no data is missed when new
topics/partitions are dynamically subscribed. Note that `startingOffsets` only applies when a new
streaming query is started, and that resuming will always pick up from where the query left off.
-- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use
+- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use
DataFrame operations to explicitly deserialize the keys.
-- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer.
+- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer.
Use DataFrame operations to explicitly deserialize the values.
- **key.serializer**: Keys are always serialized with ByteArraySerializer or StringSerializer. Use
DataFrame operations to explicitly serialize the keys into either strings or byte arrays.
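Tying the `group.id` note above together, a minimal sketch that sets `groupIdPrefix` (servers, topic, and prefix are placeholders):

```scala
val df = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("subscribe", "events")
  .option("groupIdPrefix", "my-team")   // generated group.id: my-team-<UUID>-<hash>
  .load()
```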
diff --git a/examples/pom.xml b/examples/pom.xml
index e9b1d4be33974..fbc3ec40359f7 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -20,12 +20,12 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../pom.xml
- spark-examples_2.11
+ spark-examples_2.12jarSpark Project Exampleshttp://spark.apache.org/
diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
index c55b68e033964..03187aee044e4 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala
@@ -32,13 +32,13 @@ object LogQuery {
| GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR
| 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR
| 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.350 "-" - "" 265 923 934 ""
- | 62.24.11.25 images.com 1358492167 - Whatup""".stripMargin.lines.mkString,
+ | 62.24.11.25 images.com 1358492167 - Whatup""".stripMargin.split('\n').mkString,
"""10.10.10.10 - "FRED" [18/Jan/2013:18:02:37 +1100] "GET http://images.com/2013/Generic.jpg
| HTTP/1.1" 304 306 "http:/referall.com" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1;
| GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR
| 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR
| 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.352 "-" - "" 256 977 988 ""
- | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.lines.mkString
+ | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.split('\n').mkString
)
def main(args: Array[String]) {
diff --git a/external/avro/benchmarks/AvroReadBenchmark-results.txt b/external/avro/benchmarks/AvroReadBenchmark-results.txt
new file mode 100644
index 0000000000000..7900fea453b10
--- /dev/null
+++ b/external/avro/benchmarks/AvroReadBenchmark-results.txt
@@ -0,0 +1,122 @@
+================================================================================================
+SQL Single Numeric Column Scan
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum 2774 / 2815 5.7 176.4 1.0X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum 2761 / 2777 5.7 175.5 1.0X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum 2783 / 2870 5.7 176.9 1.0X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum 3256 / 3266 4.8 207.0 1.0X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum 2841 / 2867 5.5 180.6 1.0X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum 2981 / 2996 5.3 189.5 1.0X
+
+
+================================================================================================
+Int and String Scan
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum of columns 4781 / 4783 2.2 456.0 1.0X
+
+
+================================================================================================
+Partitioned Table Scan
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Data column 3372 / 3386 4.7 214.4 1.0X
+Partition column 3035 / 3064 5.2 193.0 1.1X
+Both columns 3445 / 3461 4.6 219.1 1.0X
+
+
+================================================================================================
+Repeated String Scan
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum of string length 3395 / 3401 3.1 323.8 1.0X
+
+
+================================================================================================
+String with Nulls Scan
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum of string length 5580 / 5624 1.9 532.2 1.0X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum of string length 4622 / 4623 2.3 440.8 1.0X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum of string length 3238 / 3241 3.2 308.8 1.0X
+
+
+================================================================================================
+Single Column Scan From Wide Columns
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum of single column 5472 / 5484 0.2 5218.8 1.0X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Single Column Scan from 200 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum of single column 10680 / 10701 0.1 10185.1 1.0X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+Single Column Scan from 300 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+Sum of single column 16143 / 16238 0.1 15394.9 1.0X
+
+
diff --git a/external/avro/pom.xml b/external/avro/pom.xml
index 9d8f319cc9396..ba6f20bfdbf58 100644
--- a/external/avro/pom.xml
+++ b/external/avro/pom.xml
@@ -20,12 +20,12 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../../pom.xml
- spark-avro_2.11
+ spark-avro_2.12avro
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 4fea2cb969446..207c54ce75f4c 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -138,7 +138,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
test("test NULL avro type") {
withTempPath { dir =>
val fields =
- Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava
+ Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[AnyVal])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
val datumWriter = new GenericDatumWriter[GenericRecord](schema)
@@ -161,7 +161,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val avroSchema: Schema = {
val union =
Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava)
- val fields = Seq(new Field("field1", union, "doc", null)).asJava
+ val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
schema
@@ -189,7 +189,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val avroSchema: Schema = {
val union =
Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava)
- val fields = Seq(new Field("field1", union, "doc", null)).asJava
+ val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
schema
@@ -221,7 +221,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
Schema.create(Type.NULL)
).asJava
)
- val fields = Seq(new Field("field1", union, "doc", null)).asJava
+ val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[AnyVal])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
schema
@@ -247,7 +247,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
test("Union of a single type") {
withTempPath { dir =>
val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava)
- val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava
+ val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[AnyVal])).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
@@ -274,10 +274,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val complexUnionType = Schema.createUnion(
List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava)
val fields = Seq(
- new Field("field1", complexUnionType, "doc", null),
- new Field("field2", complexUnionType, "doc", null),
- new Field("field3", complexUnionType, "doc", null),
- new Field("field4", complexUnionType, "doc", null)
+ new Field("field1", complexUnionType, "doc", null.asInstanceOf[AnyVal]),
+ new Field("field2", complexUnionType, "doc", null.asInstanceOf[AnyVal]),
+ new Field("field3", complexUnionType, "doc", null.asInstanceOf[AnyVal]),
+ new Field("field4", complexUnionType, "doc", null.asInstanceOf[AnyVal])
).asJava
val schema = Schema.createRecord("name", "docs", "namespace", false)
schema.setFields(fields)
@@ -508,7 +508,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val union2 = spark.read.format("avro").load(testAvro).select("union_float_double").collect()
assert(
union2
- .map(x => new java.lang.Double(x(0).toString))
+ .map(x => java.lang.Double.valueOf(x(0).toString))
.exists(p => Math.abs(p - Math.PI) < 0.001))
val fixed = spark.read.format("avro").load(testAvro).select("fixed3").collect()
@@ -941,7 +941,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable)
val avroMapType = resolveNullable(Schema.createMap(avroType), nullable)
val name = "foo"
- val avroField = new Field(name, avroType, "", null)
+ val avroField = new Field(name, avroType, "", null.asInstanceOf[AnyVal])
val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava)
val avroRecordType = resolveNullable(recordSchema, nullable)
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala
new file mode 100644
index 0000000000000..f2f7d650066fb
--- /dev/null
+++ b/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.benchmark
+
+import java.io.File
+
+import scala.util.Random
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.types._
+
+/**
+ * Benchmark to measure Avro read performance.
+ * {{{
+ * To run this benchmark:
+ * 1. without sbt: bin/spark-submit --class <this class>
+ *    --jars <spark core test jar>,<spark catalyst test jar>,<spark sql test jar>
+ * 2. build/sbt "avro/test:runMain <this class>"
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "avro/test:runMain <this class>"
+ * Results will be written to "benchmarks/AvroReadBenchmark-results.txt".
+ * }}}
+ */
+object AvroReadBenchmark extends SqlBasedBenchmark with SQLHelper {
+ def withTempTable(tableNames: String*)(f: => Unit): Unit = {
+ try f finally tableNames.foreach(spark.catalog.dropTempView)
+ }
+
+ private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
+ val dirAvro = dir.getCanonicalPath
+
+ if (partition.isDefined) {
+ df.write.partitionBy(partition.get).format("avro").save(dirAvro)
+ } else {
+ df.write.format("avro").save(dirAvro)
+ }
+
+ spark.read.format("avro").load(dirAvro).createOrReplaceTempView("avroTable")
+ }
+
+ def numericScanBenchmark(values: Int, dataType: DataType): Unit = {
+ val benchmark =
+ new Benchmark(s"SQL Single ${dataType.sql} Column Scan", values, output = output)
+
+ withTempPath { dir =>
+ withTempTable("t1", "avroTable") {
+ import spark.implicits._
+ spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1")
+
+ prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1"))
+
+ benchmark.addCase("Sum") { _ =>
+ spark.sql("SELECT sum(id) FROM avroTable").collect()
+ }
+
+ benchmark.run()
+ }
+ }
+ }
+
+ def intStringScanBenchmark(values: Int): Unit = {
+ val benchmark = new Benchmark("Int and String Scan", values, output = output)
+
+ withTempPath { dir =>
+ withTempTable("t1", "avroTable") {
+ import spark.implicits._
+ spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1")
+
+ prepareTable(
+ dir,
+ spark.sql("SELECT CAST(value AS INT) AS c1, CAST(value as STRING) AS c2 FROM t1"))
+
+ benchmark.addCase("Sum of columns") { _ =>
+ spark.sql("SELECT sum(c1), sum(length(c2)) FROM avroTable").collect()
+ }
+
+ benchmark.run()
+ }
+ }
+ }
+
+ def partitionTableScanBenchmark(values: Int): Unit = {
+ val benchmark = new Benchmark("Partitioned Table", values, output = output)
+
+ withTempPath { dir =>
+ withTempTable("t1", "avroTable") {
+ import spark.implicits._
+ spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1")
+
+ prepareTable(dir, spark.sql("SELECT value % 2 AS p, value AS id FROM t1"), Some("p"))
+
+ benchmark.addCase("Data column") { _ =>
+ spark.sql("SELECT sum(id) FROM avroTable").collect()
+ }
+
+ benchmark.addCase("Partition column") { _ =>
+ spark.sql("SELECT sum(p) FROM avroTable").collect()
+ }
+
+ benchmark.addCase("Both columns") { _ =>
+ spark.sql("SELECT sum(p), sum(id) FROM avroTable").collect()
+ }
+
+ benchmark.run()
+ }
+ }
+ }
+
+ def repeatedStringScanBenchmark(values: Int): Unit = {
+ val benchmark = new Benchmark("Repeated String", values, output = output)
+
+ withTempPath { dir =>
+ withTempTable("t1", "avroTable") {
+ spark.range(values).createOrReplaceTempView("t1")
+
+ prepareTable(dir, spark.sql("SELECT CAST((id % 200) + 10000 as STRING) AS c1 FROM t1"))
+
+ benchmark.addCase("Sum of string length") { _ =>
+ spark.sql("SELECT sum(length(c1)) FROM avroTable").collect()
+ }
+
+ benchmark.run()
+ }
+ }
+ }
+
+ def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = {
+ withTempPath { dir =>
+ withTempTable("t1", "avroTable") {
+ spark.range(values).createOrReplaceTempView("t1")
+
+ prepareTable(
+ dir,
+ spark.sql(
+ s"SELECT IF(RAND(1) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c1, " +
+ s"IF(RAND(2) < $fractionOfNulls, NULL, CAST(id as STRING)) AS c2 FROM t1"))
+
+ val percentageOfNulls = fractionOfNulls * 100
+ val benchmark =
+ new Benchmark(s"String with Nulls Scan ($percentageOfNulls%)", values, output = output)
+
+ benchmark.addCase("Sum of string length") { _ =>
+ spark.sql("SELECT SUM(LENGTH(c2)) FROM avroTable " +
+ "WHERE c1 IS NOT NULL AND c2 IS NOT NULL").collect()
+ }
+
+ benchmark.run()
+ }
+ }
+ }
+
+ def columnsBenchmark(values: Int, width: Int): Unit = {
+ val benchmark =
+ new Benchmark(s"Single Column Scan from $width columns", values, output = output)
+
+ withTempPath { dir =>
+ withTempTable("t1", "avroTable") {
+ import spark.implicits._
+ val middle = width / 2
+ val selectExpr = (1 to width).map(i => s"value as c$i")
+ spark.range(values).map(_ => Random.nextLong).toDF()
+ .selectExpr(selectExpr: _*).createOrReplaceTempView("t1")
+
+ prepareTable(dir, spark.sql("SELECT * FROM t1"))
+
+ benchmark.addCase("Sum of single column") { _ =>
+ spark.sql(s"SELECT sum(c$middle) FROM avroTable").collect()
+ }
+
+ benchmark.run()
+ }
+ }
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ runBenchmark("SQL Single Numeric Column Scan") {
+ Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType =>
+ numericScanBenchmark(1024 * 1024 * 15, dataType)
+ }
+ }
+ runBenchmark("Int and String Scan") {
+ intStringScanBenchmark(1024 * 1024 * 10)
+ }
+ runBenchmark("Partitioned Table Scan") {
+ partitionTableScanBenchmark(1024 * 1024 * 15)
+ }
+ runBenchmark("Repeated String Scan") {
+ repeatedStringScanBenchmark(1024 * 1024 * 10)
+ }
+ runBenchmark("String with Nulls Scan") {
+ for (fractionOfNulls <- List(0.0, 0.50, 0.95)) {
+ stringWithNullsScanBenchmark(1024 * 1024 * 10, fractionOfNulls)
+ }
+ }
+ runBenchmark("Single Column Scan From Wide Columns") {
+ columnsBenchmark(1024 * 1024 * 1, 100)
+ columnsBenchmark(1024 * 1024 * 1, 200)
+ columnsBenchmark(1024 * 1024 * 1, 300)
+ }
+ }
+}
diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml
index f24254b698080..b39db7540b7d2 100644
--- a/external/docker-integration-tests/pom.xml
+++ b/external/docker-integration-tests/pom.xml
@@ -21,12 +21,12 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../../pom.xml
- spark-docker-integration-tests_2.11
+ spark-docker-integration-tests_2.12jarSpark Project Docker Integration Testshttp://spark.apache.org/
diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml
index 4f9c3163b2408..f2dcf5d217a89 100644
--- a/external/kafka-0-10-assembly/pom.xml
+++ b/external/kafka-0-10-assembly/pom.xml
@@ -20,12 +20,12 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../../pom.xml
- spark-streaming-kafka-0-10-assembly_2.11
+ spark-streaming-kafka-0-10-assembly_2.12jarSpark Integration for Kafka 0.10 Assemblyhttp://spark.apache.org/
diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml
index efd0862fb58ee..1af407167597b 100644
--- a/external/kafka-0-10-sql/pom.xml
+++ b/external/kafka-0-10-sql/pom.xml
@@ -20,17 +20,17 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../../pom.xmlorg.apache.spark
- spark-sql-kafka-0-10_2.11
+ spark-sql-kafka-0-10_2.12sql-kafka-0-10
- 2.0.0
+ 2.1.0jarKafka 0.10+ Source for Structured Streaming
@@ -89,6 +89,13 @@
+
+
+ org.apache.zookeeper
+ zookeeper
+ 3.4.7
+ test
+ net.sf.jopt-simplejopt-simple
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 28c9853bfea9c..f770f0c2a04c2 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -77,7 +77,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
- val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
+ val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath)
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams =
@@ -119,7 +119,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
- val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
+ val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath)
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams =
@@ -159,7 +159,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
- val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
+ val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath)
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams =
@@ -510,7 +510,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
.set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
// So that the driver does not pull too much data
- .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1))
+ .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, java.lang.Integer.valueOf(1))
// If buffer config is not set, set it to reasonable value to work around
// buffer issues (see KAFKA-3135)
@@ -538,6 +538,18 @@ private[kafka010] object KafkaSourceProvider extends Logging {
.setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
.build()
+ /**
+ * Returns a unique consumer group (group.id), allowing the user to set the prefix of
+ * the consumer group
+ */
+ private def streamingUniqueGroupId(
+ parameters: Map[String, String],
+ metadataPath: String): String = {
+ val groupIdPrefix = parameters
+ .getOrElse("groupIdPrefix", "spark-kafka-source")
+ s"${groupIdPrefix}-${UUID.randomUUID}-${metadataPath.hashCode}"
+ }
+
/** Class to conveniently update Kafka config params, while logging the changes */
private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) {
private val map = new ju.HashMap[String, Object](kafkaParams.asJava)
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
index fa6bdc20bd4f9..aa21f1271b817 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
@@ -56,7 +56,7 @@ trait KafkaContinuousTest extends KafkaSourceTest {
}
// Continuous processing tasks end asynchronously, so test that they actually end.
- private val tasksEndedListener = new SparkListener() {
+ private class TasksEndedListener extends SparkListener {
val activeTaskIdCount = new AtomicInteger(0)
override def onTaskStart(start: SparkListenerTaskStart): Unit = {
@@ -68,6 +68,8 @@ trait KafkaContinuousTest extends KafkaSourceTest {
}
}
+ private val tasksEndedListener = new TasksEndedListener()
+
override def beforeEach(): Unit = {
super.beforeEach()
spark.sparkContext.addSparkListener(tasksEndedListener)
diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml
index f59f07265a0f4..ea18b7e035915 100644
--- a/external/kafka-0-10/pom.xml
+++ b/external/kafka-0-10/pom.xml
@@ -20,16 +20,16 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../../pom.xml
- spark-streaming-kafka-0-10_2.11
+ spark-streaming-kafka-0-10_2.12streaming-kafka-0-10
- 2.0.0
+ 2.1.0jarSpark Integration for Kafka 0.10
@@ -74,6 +74,13 @@
+
+
+ org.apache.zookeeper
+ zookeeper
+ 3.4.7
+ test
+ net.sf.jopt-simplejopt-simple
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala
index cf283a5c3e11e..07960d14b0bfc 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala
@@ -228,7 +228,7 @@ object ConsumerStrategies {
new Subscribe[K, V](
new ju.ArrayList(topics.asJavaCollection),
new ju.HashMap[String, Object](kafkaParams.asJava),
- new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava))
+ new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava))
}
/**
@@ -307,7 +307,7 @@ object ConsumerStrategies {
new SubscribePattern[K, V](
pattern,
new ju.HashMap[String, Object](kafkaParams.asJava),
- new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava))
+ new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava))
}
/**
@@ -391,7 +391,7 @@ object ConsumerStrategies {
new Assign[K, V](
new ju.ArrayList(topicPartitions.asJavaCollection),
new ju.HashMap[String, Object](kafkaParams.asJava),
- new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava))
+ new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(jl.Long.valueOf).asJava))
}
/**
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
index ba4009ef08856..224f41a683955 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
@@ -70,7 +70,7 @@ private[spark] class DirectKafkaInputDStream[K, V](
@transient private var kc: Consumer[K, V] = null
def consumer(): Consumer[K, V] = this.synchronized {
if (null == kc) {
- kc = consumerStrategy.onStart(currentOffsets.mapValues(l => new java.lang.Long(l)).asJava)
+ kc = consumerStrategy.onStart(currentOffsets.mapValues(l => java.lang.Long.valueOf(l)).asJava)
}
kc
}
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala
index 64b6ef6c53b6d..2516b948f6650 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala
@@ -56,7 +56,7 @@ object KafkaUtils extends Logging {
): RDD[ConsumerRecord[K, V]] = {
val preferredHosts = locationStrategy match {
case PreferBrokers =>
- throw new AssertionError(
+ throw new IllegalArgumentException(
"If you want to prefer brokers, you must provide a mapping using PreferFixed " +
"A single KafkaRDD does not have a driver consumer and cannot look up brokers for you.")
case PreferConsistent => ju.Collections.emptyMap[TopicPartition, String]()
diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml
index 0bf4c265939e7..0ce922349ea66 100644
--- a/external/kinesis-asl-assembly/pom.xml
+++ b/external/kinesis-asl-assembly/pom.xml
@@ -20,12 +20,12 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../../pom.xml
- spark-streaming-kinesis-asl-assembly_2.11
+ spark-streaming-kinesis-asl-assembly_2.12jarSpark Project Kinesis Assemblyhttp://spark.apache.org/
diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml
index 0aef25329db99..7d69764b77de7 100644
--- a/external/kinesis-asl/pom.xml
+++ b/external/kinesis-asl/pom.xml
@@ -19,13 +19,13 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../../pom.xml
- spark-streaming-kinesis-asl_2.11
+ spark-streaming-kinesis-asl_2.12jarSpark Kinesis Integration
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
index 1ffec01df9f00..d4a428f45c110 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.streaming.{Duration, StreamingContext, Time}
@@ -84,14 +84,14 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
}
}
-@InterfaceStability.Evolving
+@Evolving
object KinesisInputDStream {
/**
* Builder for [[KinesisInputDStream]] instances.
*
* @since 2.2.0
*/
- @InterfaceStability.Evolving
+ @Evolving
class Builder {
// Required params
private var streamingContext: Option[StreamingContext] = None
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala
index 9facfe8ff2b0f..dcb60b21d9851 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala
@@ -14,13 +14,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.streaming.kinesis
-import scala.collection.JavaConverters._
+package org.apache.spark.streaming.kinesis
import com.amazonaws.auth._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
import org.apache.spark.internal.Logging
/**
@@ -84,14 +83,14 @@ private[kinesis] final case class STSCredentials(
}
}
-@InterfaceStability.Evolving
+@Evolving
object SparkAWSCredentials {
/**
* Builder for [[SparkAWSCredentials]] instances.
*
* @since 2.2.0
*/
- @InterfaceStability.Evolving
+ @Evolving
class Builder {
private var basicCreds: Option[BasicCredentials] = None
private var stsCreds: Option[STSCredentials] = None
diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml
index 35a55b70baf33..a23d255f9187c 100644
--- a/external/spark-ganglia-lgpl/pom.xml
+++ b/external/spark-ganglia-lgpl/pom.xml
@@ -19,13 +19,13 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../../pom.xml
- spark-ganglia-lgpl_2.11
+ spark-ganglia-lgpl_2.12jarSpark Ganglia Integration
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 22bc148e068a5..444568a03d6c7 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -20,12 +20,12 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../pom.xml
- spark-graphx_2.11
+ spark-graphx_2.12graphx
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
index a4e293d74a012..184b96426fa9b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
@@ -117,13 +117,11 @@ class ShippableVertexPartition[VD: ClassTag](
val initialSize = if (shipSrc && shipDst) routingTable.partitionSize(pid) else 64
val vids = new PrimitiveVector[VertexId](initialSize)
val attrs = new PrimitiveVector[VD](initialSize)
- var i = 0
routingTable.foreachWithinEdgePartition(pid, shipSrc, shipDst) { vid =>
if (isDefined(vid)) {
vids += vid
attrs += this(vid)
}
- i += 1
}
(pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array))
}
@@ -137,12 +135,10 @@ class ShippableVertexPartition[VD: ClassTag](
def shipVertexIds(): Iterator[(PartitionID, Array[VertexId])] = {
Iterator.tabulate(routingTable.numEdgePartitions) { pid =>
val vids = new PrimitiveVector[VertexId](routingTable.partitionSize(pid))
- var i = 0
routingTable.foreachWithinEdgePartition(pid, true, true) { vid =>
if (isDefined(vid)) {
vids += vid
}
- i += 1
}
(pid, vids.trim().array)
}
diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml
index 5d08d8d92e577..fc98f2e0023ce 100644
--- a/hadoop-cloud/pom.xml
+++ b/hadoop-cloud/pom.xml
@@ -21,12 +21,12 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../pom.xml
- spark-hadoop-cloud_2.11
+ spark-hadoop-cloud_2.12jarSpark Project Cloud Integration
diff --git a/launcher/pom.xml b/launcher/pom.xml
index a833a35399918..130519d6c3b08 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -21,12 +21,12 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../pom.xml
- spark-launcher_2.11
+ spark-launcher_2.12jarSpark Project Launcherhttp://spark.apache.org/
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java
index 9cbebdaeb33d3..0999cbd216871 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractAppHandle.java
@@ -31,8 +31,8 @@ abstract class AbstractAppHandle implements SparkAppHandle {
private final LauncherServer server;
private LauncherServer.ServerConnection connection;
-  private List<Listener> listeners;
-  private AtomicReference<State> state;
+  private List<SparkAppHandle.Listener> listeners;
+  private AtomicReference<SparkAppHandle.State> state;
private volatile String appId;
private volatile boolean disposed;
@@ -42,7 +42,7 @@ protected AbstractAppHandle(LauncherServer server) {
}
@Override
- public synchronized void addListener(Listener l) {
+ public synchronized void addListener(SparkAppHandle.Listener l) {
if (listeners == null) {
listeners = new CopyOnWriteArrayList<>();
}
@@ -50,7 +50,7 @@ public synchronized void addListener(Listener l) {
}
@Override
- public State getState() {
+ public SparkAppHandle.State getState() {
return state.get();
}
@@ -120,11 +120,11 @@ synchronized void dispose() {
}
}
- void setState(State s) {
+ void setState(SparkAppHandle.State s) {
setState(s, false);
}
- void setState(State s, boolean force) {
+ void setState(SparkAppHandle.State s, boolean force) {
if (force) {
state.set(s);
fireEvent(false);
diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml
index ec5f9b0e92c8f..2eab868ac0dc8 100644
--- a/mllib-local/pom.xml
+++ b/mllib-local/pom.xml
@@ -20,12 +20,12 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../pom.xml
- spark-mllib-local_2.11
+ spark-mllib-local_2.12mllib-local
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index 5824e463ca1aa..6e950f968a65d 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -106,7 +106,7 @@ sealed trait Vector extends Serializable {
*/
@Since("2.0.0")
def copy: Vector = {
- throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.")
+ throw new UnsupportedOperationException(s"copy is not implemented for ${this.getClass}.")
}
/**
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
index 3167e0c286d47..e7f7a8e07d7f2 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
@@ -48,14 +48,14 @@ class MultivariateGaussian @Since("2.0.0") (
this(Vectors.fromBreeze(mean), Matrices.fromBreeze(cov))
}
- private val breezeMu = mean.asBreeze.toDenseVector
+ @transient private lazy val breezeMu = mean.asBreeze.toDenseVector
/**
* Compute distribution dependent constants:
* rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
* u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
*/
- private val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants
+ @transient private lazy val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants
/**
* Returns density of this multivariate Gaussian at given point, x
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala
index ace44165b1067..332734bd28341 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala
@@ -862,10 +862,10 @@ class MatricesSuite extends SparkMLFunSuite {
mat.toString(0, 0)
mat.toString(Int.MinValue, Int.MinValue)
mat.toString(Int.MaxValue, Int.MaxValue)
- var lines = mat.toString(6, 50).lines.toArray
+ var lines = mat.toString(6, 50).split('\n')
assert(lines.size == 5 && lines.forall(_.size <= 50))
- lines = mat.toString(5, 100).lines.toArray
+ lines = mat.toString(5, 100).split('\n')
assert(lines.size == 5 && lines.forall(_.size <= 100))
}
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 17ddb87c4d86a..0b17345064a71 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -20,12 +20,12 @@
4.0.0org.apache.spark
- spark-parent_2.11
+ spark-parent_2.123.0.0-SNAPSHOT../pom.xml
- spark-mllib_2.11
+ spark-mllib_2.12mllib
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 62cfa39746ff0..2c4186a13d8f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -180,7 +180,6 @@ class GBTClassifier @Since("1.4.0") (
(convert2LabeledPoint(dataset), null)
}
- val numFeatures = trainDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val numClasses = 2
@@ -196,7 +195,6 @@ class GBTClassifier @Since("1.4.0") (
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
validationIndicatorCol)
- instr.logNumFeatures(numFeatures)
instr.logNumClasses(numClasses)
val (baseLearners, learnerWeights) = if (withValidation) {
@@ -206,6 +204,9 @@ class GBTClassifier @Since("1.4.0") (
GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy))
}
+ val numFeatures = baseLearners.head.numFeatures
+ instr.logNumFeatures(numFeatures)
+
new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
}
@@ -427,7 +428,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
s" trees based on metadata but found ${trees.length} trees.")
val model = new GBTClassificationModel(metadata.uid,
trees, treeWeights, numFeatures)
- metadata.getAndSetParams(model)
+ // We ignore the impurity while loading models because in previous models it was wrongly
+ // set to gini (see SPARK-25959).
+ metadata.getAndSetParams(model, Some(List("impurity")))
model
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 51495c1a74e69..1a7a5e7a52344 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -146,7 +146,7 @@ class NaiveBayes @Since("1.5.0") (
requireZeroOneBernoulliValues
case _ =>
// This should never happen.
- throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
+ throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}
}
@@ -196,7 +196,7 @@ class NaiveBayes @Since("1.5.0") (
case Bernoulli => math.log(n + 2.0 * lambda)
case _ =>
// This should never happen.
- throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
+ throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}
var j = 0
while (j < numFeatures) {
@@ -295,7 +295,7 @@ class NaiveBayesModel private[ml] (
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
case _ =>
// This should never happen.
- throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
+ throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}
@Since("1.6.0")
@@ -329,7 +329,7 @@ class NaiveBayesModel private[ml] (
bernoulliCalculation(features)
case _ =>
// This should never happen.
- throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
+ throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 57132381b6474..7598a28b6f89d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -142,7 +142,7 @@ class RandomForestClassifier @Since("1.4.0") (
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])
- val numFeatures = oldDataset.first().features.size
+ val numFeatures = trees.head.numFeatures
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)
new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 498310d6644e1..5d02305aafdda 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -279,7 +279,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
/**
* K-means clustering with support for k-means|| initialization proposed by Bahmani et al.
*
- * @see Bahmani et al., Scalable k-means++.
+ * @see Bahmani et al., Scalable k-means++.
*/
@Since("1.5.0")
class KMeans @Since("1.5.0") (
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 794b1e7d9d881..f1602c1bc5333 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{Dataset, Row}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.5.0")
@Experimental
class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
- extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
+ extends Evaluator with HasPredictionCol with HasLabelCol
+ with HasWeightCol with DefaultParamsWritable {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("mcEval"))
@@ -67,6 +68,10 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
@Since("1.5.0")
def setLabelCol(value: String): this.type = set(labelCol, value)
+ /** @group setParam */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
setDefault(metricName -> "f1")
@Since("2.0.0")
@@ -75,11 +80,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
SchemaUtils.checkNumericType(schema, $(labelCol))
- val predictionAndLabels =
- dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
- case Row(prediction: Double, label: Double) => (prediction, label)
+ val predictionAndLabelsWithWeights =
+ dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType),
+ if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
+ .rdd.map {
+ case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight)
}
- val metrics = new MulticlassMetrics(predictionAndLabels)
+ val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights)
val metric = $(metricName) match {
case "f1" => metrics.weightedFMeasure
case "weightedPrecision" => metrics.weightedPrecision
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index f99649f7fa164..0b989b0d7d253 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -89,7 +89,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
def setOutputCol(value: String): this.type = set(outputCol, value)
/**
- * Param for how to handle invalid entries. Options are 'skip' (filter out rows with
+ * Param for how to handle invalid entries containing NaN values. Values outside the splits
+ * will always be treated as errors. Options are 'skip' (filter out rows with
* invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special
* additional bucket). Note that in the multiple column case, the invalid handling is applied
* to all columns. That said for 'error' it will throw an error if any invalids are found in
@@ -99,7 +100,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
*/
@Since("2.1.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
- "how to handle invalid entries. Options are skip (filter out rows with invalid values), " +
+ "how to handle invalid entries containing NaN values. Values outside the splits will always " +
+ "be treated as errorsOptions are skip (filter out rows with invalid values), " +
"error (throw an error), or keep (keep invalid values in a special additional bucket).",
ParamValidators.inArray(Bucketizer.supportedHandleInvalids))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
index c5d0ec1a8d350..412954f7b2d5a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
@@ -17,8 +17,6 @@
package org.apache.spark.ml.feature
-import scala.beans.BeanInfo
-
import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.Vector
@@ -30,8 +28,12 @@ import org.apache.spark.ml.linalg.Vector
* @param features List of features for this data point.
*/
@Since("2.0.0")
-@BeanInfo
case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features: Vector) {
+
+ def getLabel: Double = label
+
+ def getFeatures: Vector = features
+
override def toString: String = {
s"($label,$features)"
}
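A small sketch of the effect of dropping @BeanInfo here: the explicit accessors keep the bean-style getters the annotation used to generate, so Java callers keep compiling.

import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors

val point = LabeledPoint(1.0, Vectors.dense(0.5, 1.5))
// Same values either way; only how the getters are produced changed.
assert(point.getLabel == point.label)
assert(point.getFeatures == point.features)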
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 56e2c543d100a..5bfaa3b7f3f52 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -17,10 +17,6 @@
package org.apache.spark.ml.feature
-import org.json4s.JsonDSL._
-import org.json4s.JValue
-import org.json4s.jackson.JsonMethods._
-
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
@@ -209,7 +205,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
if (isSet(inputCols)) {
val splitsArray = if (isSet(numBucketsArray)) {
val probArrayPerCol = $(numBucketsArray).map { numOfBuckets =>
- (0.0 to 1.0 by 1.0 / numOfBuckets).toArray
+ (0 to numOfBuckets).map(_.toDouble / numOfBuckets).toArray
}
val probabilityArray = probArrayPerCol.flatten.sorted.distinct
@@ -229,12 +225,12 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
}
} else {
dataset.stat.approxQuantile($(inputCols),
- (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError))
+ (0 to $(numBuckets)).map(_.toDouble / $(numBuckets)).toArray, $(relativeError))
}
bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits))
} else {
val splits = dataset.stat.approxQuantile($(inputCol),
- (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError))
+ (0 to $(numBuckets)).map(_.toDouble / $(numBuckets)).toArray, $(relativeError))
bucketizer.setSplits(getDistinctSplits(splits))
}
copyValues(bucketizer.setParent(this))
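A minimal sketch, assuming this is the motivation, of the probability-array construction used in the three hunks above: dividing integer indices by numBuckets yields exact 0.0 and 1.0 endpoints, whereas stepping a Double range by 1.0 / numBuckets is deprecated and can lose the upper endpoint to rounding drift. The helper name is illustrative only.

// Illustrative helper (not part of the patch).
def quantileProbabilities(numBuckets: Int): Array[Double] =
  (0 to numBuckets).map(_.toDouble / numBuckets).toArray

quantileProbabilities(4)   // Array(0.0, 0.25, 0.5, 0.75, 1.0)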
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 840a89b76d26b..7322815c12ab8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -118,10 +118,10 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol {
/**
* :: Experimental ::
* A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in
- * Li et al., PFP: Parallel FP-Growth for Query
+ * Li et al., PFP: Parallel FP-Growth for Query
* Recommendation. PFP distributes computation in such a way that each worker executes an
* independent group of mining tasks. The FP-Growth algorithm is described in
- * Han et al., Mining frequent patterns without
+ * Han et al., Mining frequent patterns without
* candidate generation. Note null values in the itemsCol column are ignored during fit().
*
* @see
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
index bd1c1a8885201..2a3413553a6af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
* A parallel PrefixSpan algorithm to mine frequent sequential patterns.
* The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
* Efficiently by Prefix-Projected Pattern Growth
- * (see here).
+ * (see here).
* This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to
* run the PrefixSpan algorithm.
*
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index e6c347ed17c15..4c50f1e3292bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -97,7 +97,7 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
case m: Matrix =>
JsonMatrixConverter.toJson(m)
case _ =>
- throw new NotImplementedError(
+ throw new UnsupportedOperationException(
"The default jsonEncode only supports string, vector and matrix. " +
s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.")
}
@@ -151,7 +151,7 @@ private[ml] object Param {
}
case _ =>
- throw new NotImplementedError(
+ throw new UnsupportedOperationException(
"The default jsonDecode only supports string, vector and matrix. " +
s"${this.getClass.getName} must override jsonDecode to support its value type.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index ffe592789b3cc..50ef4330ddc80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -557,7 +557,7 @@ object ALSModel extends MLReadable[ALSModel] {
*
* For implicit preference data, the algorithm used is based on
* "Collaborative Filtering for Implicit Feedback Datasets", available at
- * http://dx.doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here.
+ * https://doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here.
*
* Essentially instead of finding the low-rank approximations to the rating matrix `R`,
* this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 6fa656275c1fd..c9de85de42fa5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -145,7 +145,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
@Since("1.4.0")
object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] {
/** Accessor for supported impurities: variance */
- final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+ final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities
@Since("2.0.0")
override def load(path: String): DecisionTreeRegressor = super.load(path)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 07f88d8d5f84d..88dee2507bf7e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -166,7 +166,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
} else {
(extractLabeledPoints(dataset), null)
}
- val numFeatures = trainDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
instr.logPipelineStage(this)
@@ -174,7 +173,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
- instr.logNumFeatures(numFeatures)
val (baseLearners, learnerWeights) = if (withValidation) {
GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
@@ -183,6 +181,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
GradientBoostedTrees.run(trainDataset, boostingStrategy,
$(seed), $(featureSubsetStrategy))
}
+
+ val numFeatures = baseLearners.head.numFeatures
+ instr.logNumFeatures(numFeatures)
+
new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 82bf66ff66d8a..a548ec537bb44 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -133,7 +133,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])
- val numFeatures = oldDataset.first().features.size
+ val numFeatures = trees.head.numFeatures
instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures)
new RandomForestRegressionModel(uid, trees, numFeatures)
}
@@ -146,7 +146,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
- final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+ final val supportedImpurities: Array[String] = HasVarianceImpurity.supportedImpurities
/** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
@Since("1.4.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 00157fe63af41..f1e3836ebe476 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -258,11 +258,7 @@ private[ml] object TreeClassifierParams {
private[ml] trait DecisionTreeClassifierParams
extends DecisionTreeParams with TreeClassifierParams
-/**
- * Parameters for Decision Tree-based regression algorithms.
- */
-private[ml] trait TreeRegressorParams extends Params {
-
+private[ml] trait HasVarianceImpurity extends Params {
/**
* Criterion used for information gain calculation (case-insensitive).
* Supported: "variance".
@@ -271,9 +267,9 @@ private[ml] trait TreeRegressorParams extends Params {
*/
final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
" information gain calculation (case-insensitive). Supported options:" +
- s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
+ s" ${HasVarianceImpurity.supportedImpurities.mkString(", ")}",
(value: String) =>
- TreeRegressorParams.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
+ HasVarianceImpurity.supportedImpurities.contains(value.toLowerCase(Locale.ROOT)))
setDefault(impurity -> "variance")
@@ -299,12 +295,17 @@ private[ml] trait TreeRegressorParams extends Params {
}
}
-private[ml] object TreeRegressorParams {
+private[ml] object HasVarianceImpurity {
// These options should be lowercase.
final val supportedImpurities: Array[String] =
Array("variance").map(_.toLowerCase(Locale.ROOT))
}
+/**
+ * Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends HasVarianceImpurity
+
private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams
with TreeRegressorParams with HasVarianceCol {
@@ -538,7 +539,7 @@ private[ml] object GBTClassifierParams {
Array("logistic").map(_.toLowerCase(Locale.ROOT))
}
-private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
+private[ml] trait GBTClassifierParams extends GBTParams with HasVarianceImpurity {
/**
* Loss function which GBT tries to minimize. (case-insensitive)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 135828815504a..6d46ea0adcc9a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -140,8 +140,8 @@ private[ml] object ValidatorParams {
"value" -> compact(render(JString(relativePath))),
"isJson" -> compact(render(JBool(false))))
case _: MLWritable =>
- throw new NotImplementedError("ValidatorParams.saveImpl does not handle parameters " +
- "of type: MLWritable that are not DefaultParamsWritable")
+ throw new UnsupportedOperationException("ValidatorParams.saveImpl does not handle" +
+ " parameters of type: MLWritable that are not DefaultParamsWritable")
case _ =>
Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v),
"isJson" -> compact(render(JBool(true))))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index a0ac26a34d8c8..fbc7be25a5640 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -31,7 +31,7 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{SparkContext, SparkException}
-import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
+import org.apache.spark.annotation.{DeveloperApi, Since, Unstable}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
@@ -84,7 +84,7 @@ private[util] sealed trait BaseReadWrite {
*
* @since 2.4.0
*/
-@InterfaceStability.Unstable
+@Unstable
@Since("2.4.0")
trait MLWriterFormat {
/**
@@ -108,7 +108,7 @@ trait MLWriterFormat {
*
* @since 2.4.0
*/
-@InterfaceStability.Unstable
+@Unstable
@Since("2.4.0")
trait MLFormatRegister extends MLWriterFormat {
/**
@@ -208,7 +208,7 @@ abstract class MLWriter extends BaseReadWrite with Logging {
/**
* A ML Writer which delegates based on the requested format.
*/
-@InterfaceStability.Unstable
+@Unstable
@Since("2.4.0")
class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
private var source: String = "internal"
@@ -256,7 +256,7 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
s"Multiple writers found for $source+$stageName, try using the class name of the writer")
}
if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
- val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat]
+ val writer = writerCls.getConstructor().newInstance().asInstanceOf[MLWriterFormat]
writer.write(path, sparkSession, optionMap, stage)
} else {
throw new SparkException(s"ML source $source is not a valid MLWriterFormat")
@@ -291,7 +291,7 @@ trait MLWritable {
* Trait for classes that provide `GeneralMLWriter`.
*/
@Since("2.4.0")
-@InterfaceStability.Unstable
+@Unstable
trait GeneralMLWritable extends MLWritable {
/**
* Returns an `MLWriter` instance for this ML instance.
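A hedged sketch of the reflection pattern adopted above: Class.newInstance is deprecated because it propagates checked constructor exceptions unchecked, so instantiation now goes through the explicit no-arg constructor. The class name below is hypothetical.

// Hypothetical writer class; only the reflection pattern matters.
val writerCls = Class.forName("com.example.MyWriterFormat")
// Constructor failures surface as InvocationTargetException instead of leaking checked exceptions.
val writer = writerCls.getConstructor().newInstance()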
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 9e8774732efe6..16ba6cabdc823 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -83,7 +83,7 @@ class NaiveBayesModel private[spark] (
(Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
case _ =>
// This should never happen.
- throw new UnknownError(s"Invalid modelType: $modelType.")
+ throw new IllegalArgumentException(s"Invalid modelType: $modelType.")
}
@Since("1.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index 980e0c92531a2..ad83c24ede964 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -27,10 +27,19 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for multiclass classification.
*
- * @param predictionAndLabels an RDD of (prediction, label) pairs.
+ * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or
+ * (prediction, label) tuples.
*/
@Since("1.1.0")
-class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
+class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) {
+ val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map {
+ case (prediction: Double, label: Double, weight: Double) =>
+ (prediction, label, weight)
+ case (prediction: Double, label: Double) =>
+ (prediction, label, 1.0)
+ case other =>
+ throw new IllegalArgumentException(s"Expected tuples, got $other")
+ }
/**
* An auxiliary constructor taking a DataFrame.
@@ -39,21 +48,29 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
private[mllib] def this(predictionAndLabels: DataFrame) =
this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
- private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
- private lazy val labelCount: Long = labelCountByClass.values.sum
- private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
- .map { case (prediction, label) =>
- (label, if (label == prediction) 1 else 0)
+ private lazy val labelCountByClass: Map[Double, Double] =
+ predLabelsWeight.map {
+ case (_: Double, label: Double, weight: Double) =>
+ (label, weight)
+ }.reduceByKey(_ + _)
+ .collectAsMap()
+ private lazy val labelCount: Double = labelCountByClass.values.sum
+ private lazy val tpByClass: Map[Double, Double] = predLabelsWeight
+ .map {
+ case (prediction: Double, label: Double, weight: Double) =>
+ (label, if (label == prediction) weight else 0.0)
}.reduceByKey(_ + _)
.collectAsMap()
- private lazy val fpByClass: Map[Double, Int] = predictionAndLabels
- .map { case (prediction, label) =>
- (prediction, if (prediction != label) 1 else 0)
+ private lazy val fpByClass: Map[Double, Double] = predLabelsWeight
+ .map {
+ case (prediction: Double, label: Double, weight: Double) =>
+ (prediction, if (prediction != label) weight else 0.0)
}.reduceByKey(_ + _)
.collectAsMap()
- private lazy val confusions = predictionAndLabels
- .map { case (prediction, label) =>
- ((label, prediction), 1)
+ private lazy val confusions = predLabelsWeight
+ .map {
+ case (prediction: Double, label: Double, weight: Double) =>
+ ((label, prediction), weight)
}.reduceByKey(_ + _)
.collectAsMap()
@@ -71,7 +88,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
while (i < n) {
var j = 0
while (j < n) {
- values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble
+ values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0.0)
j += 1
}
i += 1
@@ -92,8 +109,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
*/
@Since("1.1.0")
def falsePositiveRate(label: Double): Double = {
- val fp = fpByClass.getOrElse(label, 0)
- fp.toDouble / (labelCount - labelCountByClass(label))
+ val fp = fpByClass.getOrElse(label, 0.0)
+ fp / (labelCount - labelCountByClass(label))
}
/**
@@ -103,7 +120,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
@Since("1.1.0")
def precision(label: Double): Double = {
val tp = tpByClass(label)
- val fp = fpByClass.getOrElse(label, 0)
+ val fp = fpByClass.getOrElse(label, 0.0)
if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
}
@@ -112,7 +129,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
* @param label the label.
*/
@Since("1.1.0")
- def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label)
+ def recall(label: Double): Double = tpByClass(label) / labelCountByClass(label)
/**
* Returns f-measure for a given label (category)
@@ -140,7 +157,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
* out of the total number of instances.)
*/
@Since("2.0.0")
- lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount
+ lazy val accuracy: Double = tpByClass.values.sum / labelCount
/**
* Returns weighted true positive rate
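A short usage sketch of the widened MulticlassMetrics constructor, assuming an existing SparkContext sc: both the old (prediction, label) pairs and the new (prediction, label, weight) triples are accepted.

import org.apache.spark.mllib.evaluation.MulticlassMetrics

val weighted = sc.parallelize(Seq((0.0, 0.0, 2.0), (1.0, 0.0, 0.5), (1.0, 1.0, 1.0)))
val unweighted = sc.parallelize(Seq((0.0, 0.0), (1.0, 0.0), (1.0, 1.0)))
// Per-class counts are now weight sums, so accuracy here is (2.0 + 1.0) / 3.5.
println(new MulticlassMetrics(weighted).accuracy)
println(new MulticlassMetrics(unweighted).accuracy)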
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 3a1bc35186dc3..519c1ea47c1db 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -152,10 +152,10 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] {
/**
* A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in
- * Li et al., PFP: Parallel FP-Growth for Query
+ * Li et al., PFP: Parallel FP-Growth for Query
* Recommendation. PFP distributes computation in such a way that each worker executes an
* independent group of mining tasks. The FP-Growth algorithm is described in
- * Han et al., Mining frequent patterns without
+ * Han et al., Mining frequent patterns without
* candidate generation.
*
* @param minSupport the minimal support level of the frequent pattern, any pattern that appears
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 7aed2f3bd8a61..b2c09b408b40b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -45,7 +45,7 @@ import org.apache.spark.storage.StorageLevel
* A parallel PrefixSpan algorithm to mine frequent sequential patterns.
* The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
* Efficiently by Prefix-Projected Pattern Growth
- * (see here).
+ * (see here).
*
* @param minSupport the minimal support level of the sequential pattern, any pattern that appears
* more than (minSupport * size-of-the-dataset) times will be output
@@ -174,6 +174,13 @@ class PrefixSpan private (
val freqSequences = results.map { case (seq: Array[Int], count: Long) =>
new FreqSequence(toPublicRepr(seq), count)
}
+ // Cache the final RDD to the same storage level as input
+ if (data.getStorageLevel != StorageLevel.NONE) {
+ freqSequences.persist(data.getStorageLevel)
+ freqSequences.count()
+ }
+ dataInternalRepr.unpersist(false)
+
new PrefixSpanModel(freqSequences)
}
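A usage sketch adapted from the standard MLlib PrefixSpan example (assuming an existing SparkContext sc), showing the case the new lines target: when the caller persists the input, the frequent-sequence RDD now inherits that storage level and the internal representation is released.

import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.storage.StorageLevel

val sequences = sc.parallelize(Seq(
  Array(Array(1, 2), Array(3)),
  Array(Array(1), Array(3, 2), Array(1, 2)),
  Array(Array(1, 2), Array(5)),
  Array(Array(6)))).persist(StorageLevel.MEMORY_ONLY)
val model = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5).run(sequences)
// freqSequences is cached at MEMORY_ONLY because the input was persisted.
model.freqSequences.collect().foreach { fs =>
  println(s"${fs.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}: ${fs.freq}")
}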
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6e68d9684a672..9cdf1944329b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -117,7 +117,7 @@ sealed trait Vector extends Serializable {
*/
@Since("1.1.0")
def copy: Vector = {
- throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.")
+ throw new UnsupportedOperationException(s"copy is not implemented for ${this.getClass}.")
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index 7caacd13b3459..e58860fea97d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -17,10 +17,9 @@
package org.apache.spark.mllib.linalg.distributed
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM}
import scala.collection.mutable.ArrayBuffer
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM, SparseVector => BSV, Vector => BV}
-
import org.apache.spark.{Partitioner, SparkException}
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
@@ -28,6 +27,7 @@ import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+
/**
* A grid partitioner, which uses a regular grid to partition coordinates.
*
@@ -273,24 +273,37 @@ class BlockMatrix @Since("1.3.0") (
require(cols < Int.MaxValue, s"The number of columns should be less than Int.MaxValue ($cols).")
val rows = blocks.flatMap { case ((blockRowIdx, blockColIdx), mat) =>
- mat.rowIter.zipWithIndex.map {
+ mat.rowIter.zipWithIndex.filter(_._1.size > 0).map {
case (vector, rowIdx) =>
- blockRowIdx * rowsPerBlock + rowIdx -> ((blockColIdx, vector.asBreeze))
+ blockRowIdx * rowsPerBlock + rowIdx -> ((blockColIdx, vector))
}
}.groupByKey().map { case (rowIdx, vectors) =>
- val numberNonZeroPerRow = vectors.map(_._2.activeSize).sum.toDouble / cols.toDouble
-
- val wholeVector = if (numberNonZeroPerRow <= 0.1) { // Sparse at 1/10th nnz
- BSV.zeros[Double](cols)
- } else {
- BDV.zeros[Double](cols)
- }
+ val numberNonZero = vectors.map(_._2.numActives).sum
+ val numberNonZeroPerRow = numberNonZero.toDouble / cols.toDouble
+
+ val wholeVector =
+ if (numberNonZeroPerRow <= 0.1) { // Sparse at 1/10th nnz
+ val arrBufferIndices = new ArrayBuffer[Int](numberNonZero)
+ val arrBufferValues = new ArrayBuffer[Double](numberNonZero)
+
+ vectors.foreach { case (blockColIdx: Int, vec: Vector) =>
+ val offset = colsPerBlock * blockColIdx
+ vec.foreachActive { case (colIdx: Int, value: Double) =>
+ arrBufferIndices += offset + colIdx
+ arrBufferValues += value
+ }
+ }
+ Vectors.sparse(cols, arrBufferIndices.toArray, arrBufferValues.toArray)
+ } else {
+ val wholeVectorBuf = BDV.zeros[Double](cols)
+ vectors.foreach { case (blockColIdx: Int, vec: Vector) =>
+ val offset = colsPerBlock * blockColIdx
+ wholeVectorBuf(offset until Math.min(cols, offset + colsPerBlock)) := vec.asBreeze
+ }
+ Vectors.fromBreeze(wholeVectorBuf)
+ }
- vectors.foreach { case (blockColIdx: Int, vec: BV[_]) =>
- val offset = colsPerBlock * blockColIdx
- wholeVector(offset until Math.min(cols, offset + colsPerBlock)) := vec
- }
- new IndexedRow(rowIdx, Vectors.fromBreeze(wholeVector))
+ IndexedRow(rowIdx, wholeVector)
}
new IndexedRowMatrix(rows)
}
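A small usage sketch of the conversion rewritten above, assuming an existing SparkContext sc: rows whose fraction of non-zeros is at most one tenth are now assembled as sparse vectors instead of always being densified.

import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.mllib.linalg.distributed.BlockMatrix

// A single 2x2 block holding one non-zero entry at (0, 0).
val blocks = sc.parallelize(Seq(
  ((0, 0), Matrices.sparse(2, 2, Array(0, 1, 1), Array(0), Array(3.0)))))
val matrix = new BlockMatrix(blocks, 2, 2)
// Here the all-zero second row takes the sparse path; the first row stays dense.
val rowMatrix = matrix.toIndexedRowMatrix()
rowMatrix.rows.collect().foreach(println)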
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 82ab716ed96a8..c12b751bfb8e4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -540,7 +540,7 @@ class RowMatrix @Since("1.0.0") (
* decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape.
* Reference:
* Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce
- * architectures" (see here)
+ * architectures" (see here)
*
* @param computeQ whether to computeQ
* @return QRDecomposition(Q, R), Q = null if computeQ = false.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 14288221b6945..12870f819b147 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -54,7 +54,7 @@ case class Rating @Since("0.8.0") (
*
* For implicit preference data, the algorithm used is based on
* "Collaborative Filtering for Implicit Feedback Datasets", available at
- * here, adapted for the blocked approach
+ * here, adapted for the blocked approach
* used here.
*
* Essentially instead of finding the low-rank approximations to the rating matrix `R`,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
index 4381d6ab20cc0..b320057b25276 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
@@ -17,8 +17,6 @@
package org.apache.spark.mllib.regression
-import scala.beans.BeanInfo
-
import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint}
@@ -32,10 +30,14 @@ import org.apache.spark.mllib.util.NumericParser
* @param features List of features for this data point.
*/
@Since("0.8.0")
-@BeanInfo
case class LabeledPoint @Since("1.0.0") (
@Since("0.8.0") label: Double,
@Since("1.0.0") features: Vector) {
+
+ def getLabel: Double = label
+
+ def getFeatures: Vector = features
+
override def toString: String = {
s"($label,$features)"
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index 4cf662e036346..9a746dcf35556 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -43,7 +43,7 @@ class MultivariateGaussian @Since("1.3.0") (
require(sigma.numCols == sigma.numRows, "Covariance matrix must be square")
require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size")
- private val breezeMu = mu.asBreeze.toDenseVector
+ @transient private lazy val breezeMu = mu.asBreeze.toDenseVector
/**
* private[mllib] constructor
@@ -60,7 +60,7 @@ class MultivariateGaussian @Since("1.3.0") (
* rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
* u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
*/
- private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
+ @transient private lazy val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
/**
* Returns density of this multivariate Gaussian at given point, x
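A brief usage sketch of the class touched above: with the derived Breeze fields now transient lazy vals, a serialized (for example broadcast) Gaussian recomputes them on first use after deserialization; callers are unchanged.

import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian

val gaussian = new MultivariateGaussian(
  Vectors.dense(0.0, 0.0),
  Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)))   // identity covariance
// rootSigmaInv and u are computed lazily on the first pdf call.
val density = gaussian.pdf(Vectors.dense(0.5, -0.5))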
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
index 80c6ef0ea1aa1..85ed11d6553d9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
@@ -17,8 +17,6 @@
package org.apache.spark.mllib.stat.test
-import scala.beans.BeanInfo
-
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.streaming.api.java.JavaDStream
@@ -32,10 +30,11 @@ import org.apache.spark.util.StatCounter
* @param value numeric value of the observation.
*/
@Since("1.6.0")
-@BeanInfo
case class BinarySample @Since("1.6.0") (
@Since("1.6.0") isExperiment: Boolean,
@Since("1.6.0") value: Double) {
+ def getIsExperiment: Boolean = isExperiment
+ def getValue: Double = value
override def toString: String = {
s"($isExperiment, $value)"
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 14af8b5c73870..6d15a6bb01e4e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -506,8 +506,6 @@ object MLUtils extends Logging {
val n = v1.size
require(v2.size == n)
require(norm1 >= 0.0 && norm2 >= 0.0)
- val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
- val normDiff = norm1 - norm2
var sqDist = 0.0
/*
* The relative error is
@@ -521,19 +519,23 @@ object MLUtils extends Logging {
* The bound doesn't need the inner product, so we can use it as a sufficient condition to
* check quickly whether the inner product approach is accurate.
*/
- val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
- if (precisionBound1 < precision) {
- sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)
- } else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {
- val dotValue = dot(v1, v2)
- sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
- val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
- (sqDist + EPSILON)
- if (precisionBound2 > precision) {
- sqDist = Vectors.sqdist(v1, v2)
- }
- } else {
+ if (v1.isInstanceOf[DenseVector] && v2.isInstanceOf[DenseVector]) {
sqDist = Vectors.sqdist(v1, v2)
+ } else {
+ val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
+ val normDiff = norm1 - norm2
+ val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
+ if (precisionBound1 < precision) {
+ sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)
+ } else {
+ val dotValue = dot(v1, v2)
+ sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
+ val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
+ (sqDist + EPSILON)
+ if (precisionBound2 > precision) {
+ sqDist = Vectors.sqdist(v1, v2)
+ }
+ }
}
sqDist
}
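A short sketch of the dense fast path added above: two DenseVectors now go straight to the exact Vectors.sqdist computation, and the norm-based shortcut with its precision bounds is only used when a sparse vector is involved.

import org.apache.spark.mllib.linalg.Vectors

val a = Vectors.dense(1.0, 2.0, 3.0)
val b = Vectors.dense(1.0, 0.0, 5.0)
// Exact squared Euclidean distance: 0 + 4 + 4 = 8.0
val squaredDistance = Vectors.sqdist(a, b)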
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
index ec45e32d412a9..dff00eade620f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -73,7 +73,7 @@ object PredictorSuite {
}
override def copy(extra: ParamMap): MockPredictor =
- throw new NotImplementedError()
+ throw new UnsupportedOperationException()
}
class MockPredictionModel(override val uid: String)
@@ -82,9 +82,9 @@ object PredictorSuite {
def this() = this(Identifiable.randomUID("mockpredictormodel"))
override def predict(features: Vector): Double =
- throw new NotImplementedError()
+ throw new UnsupportedOperationException()
override def copy(extra: ParamMap): MockPredictionModel =
- throw new NotImplementedError()
+ throw new UnsupportedOperationException()
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 6355e0f179496..eb5f3ca45940d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.ml.attribute
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.types._
class AttributeSuite extends SparkFunSuite {
@@ -221,4 +222,20 @@ class AttributeSuite extends SparkFunSuite {
val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata)
assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
}
+
+ test("Kryo class register") {
+ val conf = new SparkConf(false)
+ conf.set("spark.kryo.registrationRequired", "true")
+
+ val ser = new KryoSerializer(conf).newInstance()
+
+ val numericAttr = new NumericAttribute(Some("numeric"), Some(1), Some(1.0), Some(2.0))
+ val nominalAttr = new NominalAttribute(Some("nominal"), Some(2), Some(false))
+ val binaryAttr = new BinaryAttribute(Some("binary"), Some(3), Some(Array("i", "j")))
+
+ Seq(numericAttr, nominalAttr, binaryAttr).foreach { i =>
+ val i2 = ser.deserialize[Attribute](ser.serialize(i))
+ assert(i === i2)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
index 87bf2be06c2be..be52d99e54d3b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -117,10 +117,10 @@ object ClassifierSuite {
def this() = this(Identifiable.randomUID("mockclassifier"))
- override def copy(extra: ParamMap): MockClassifier = throw new NotImplementedError()
+ override def copy(extra: ParamMap): MockClassifier = throw new UnsupportedOperationException()
override def train(dataset: Dataset[_]): MockClassificationModel =
- throw new NotImplementedError()
+ throw new UnsupportedOperationException()
// Make methods public
override def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] =
@@ -133,11 +133,12 @@ object ClassifierSuite {
def this() = this(Identifiable.randomUID("mockclassificationmodel"))
- protected def predictRaw(features: Vector): Vector = throw new NotImplementedError()
+ protected def predictRaw(features: Vector): Vector = throw new UnsupportedOperationException()
- override def copy(extra: ParamMap): MockClassificationModel = throw new NotImplementedError()
+ override def copy(extra: ParamMap): MockClassificationModel =
+ throw new UnsupportedOperationException()
- override def numClasses: Int = throw new NotImplementedError()
+ override def numClasses: Int = throw new UnsupportedOperationException()
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 304977634189c..cedbaf1858ef4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -448,6 +448,7 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
model2: GBTClassificationModel): Unit = {
TreeTests.checkEqual(model, model2)
assert(model.numFeatures === model2.numFeatures)
+ assert(model.featureImportances == model2.featureImportances)
}
val gbt = new GBTClassifier()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 5f9ab98a2c3ce..a8c4f091b2aed 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -103,7 +103,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
case Bernoulli =>
expectedBernoulliProbabilities(model, features)
case _ =>
- throw new UnknownError(s"Invalid modelType: $modelType.")
+ throw new IllegalArgumentException(s"Invalid modelType: $modelType.")
}
assert(probability ~== expected relTol 1.0e-10)
}
@@ -378,7 +378,7 @@ object NaiveBayesSuite {
counts.toArray.sortBy(_._1).map(_._2)
case _ =>
// This should never happen.
- throw new UnknownError(s"Invalid modelType: $modelType.")
+ throw new IllegalArgumentException(s"Invalid modelType: $modelType.")
}
LabeledPoint(y, Vectors.dense(xi))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 4fec5bf1eb812..259e7b62f8a66 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -134,8 +134,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
assert(lrModel1.coefficients ~== lrModel2.coefficients relTol 1E-3)
assert(lrModel1.intercept ~== lrModel2.intercept relTol 1E-3)
case other =>
- throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
- s" LogisticRegressionModel but found ${other.getClass.getName}")
+ fail("Loaded OneVsRestModel expected model of type LogisticRegressionModel " +
+ s"but found ${other.getClass.getName}")
}
}
@@ -247,8 +247,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
assert(lr.getMaxIter === lr2.getMaxIter)
assert(lr.getRegParam === lr2.getRegParam)
case other =>
- throw new AssertionError(s"Loaded OneVsRest expected classifier of type" +
- s" LogisticRegression but found ${other.getClass.getName}")
+ fail("Loaded OneVsRest expected classifier of type LogisticRegression" +
+ s" but found ${other.getClass.getName}")
}
}
@@ -267,8 +267,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
assert(classifier.getMaxIter === lr2.getMaxIter)
assert(classifier.getRegParam === lr2.getRegParam)
case other =>
- throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" +
- s" LogisticRegression but found ${other.getClass.getName}")
+ fail("Loaded OneVsRestModel expected classifier of type LogisticRegression" +
+ s" but found ${other.getClass.getName}")
}
assert(model.labelMetadata === model2.labelMetadata)
@@ -278,8 +278,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
assert(lrModel1.coefficients === lrModel2.coefficients)
assert(lrModel1.intercept === lrModel2.intercept)
case other =>
- throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
- s" LogisticRegressionModel but found ${other.getClass.getName}")
+ fail(s"Loaded OneVsRestModel expected model of type LogisticRegressionModel" +
+ s" but found ${other.getClass.getName}")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
index 6734336aac39c..985e396000d05 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
@@ -17,16 +17,16 @@
package org.apache.spark.ml.feature
-import scala.beans.BeanInfo
-
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.Row
-@BeanInfo
-case class DCTTestData(vec: Vector, wantedVec: Vector)
+case class DCTTestData(vec: Vector, wantedVec: Vector) {
+ def getVec: Vector = vec
+ def getWantedVec: Vector = wantedVec
+}
class DCTSuite extends MLTest with DefaultReadWriteTest {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
index 201a335e0d7be..1483d5df4d224 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.ml.feature
-import scala.beans.BeanInfo
-
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.{DataFrame, Row}
-
-@BeanInfo
-case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
+case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) {
+ def getInputTokens: Array[String] = inputTokens
+ def getWantedNGrams: Array[String] = wantedNGrams
+}
class NGramSuite extends MLTest with DefaultReadWriteTest {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index b009038bbd833..82af05039653e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -31,7 +31,7 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
val datasetSize = 100000
val numBuckets = 5
- val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input")
+ val df = sc.parallelize(1 to datasetSize).map(_.toDouble).map(Tuple1.apply).toDF("input")
val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
@@ -114,8 +114,8 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
val spark = this.spark
import spark.implicits._
- val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
- val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input")
+ val trainDF = sc.parallelize((1 to 100).map(_.toDouble)).map(Tuple1.apply).toDF("input")
+ val testDF = sc.parallelize((-10 to 110).map(_.toDouble)).map(Tuple1.apply).toDF("input")
val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
@@ -276,10 +276,10 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
val data2 = Array.range(1, 40, 2).map(_.toDouble)
val expected2 = Array (0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0,
- 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0)
+ 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0)
val data3 = Array.range(1, 60, 3).map(_.toDouble)
- val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 4.0, 5.0,
- 5.0, 5.0, 6.0, 6.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0)
+ val expected3 = Array (0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0,
+ 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0, 9.0, 9.0)
val data = (0 until 20).map { idx =>
(data1(idx), data2(idx), data3(idx), expected1(idx), expected2(idx), expected3(idx))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index be59b0af2c78e..ba8e79f14de95 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -17,14 +17,14 @@
package org.apache.spark.ml.feature
-import scala.beans.BeanInfo
-
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.sql.{DataFrame, Row}
-@BeanInfo
-case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
+case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) {
+ def getRawText: String = rawText
+ def getWantedTokens: Array[String] = wantedTokens
+}
class TokenizerSuite extends MLTest with DefaultReadWriteTest {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index e5675e31bbecf..44b0f8f8ae7d8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.ml.feature
-import scala.beans.{BeanInfo, BeanProperty}
-
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.ml.attribute._
@@ -26,7 +24,7 @@ import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.DataFrame
class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging {
@@ -283,7 +281,9 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging {
points.zip(rows.map(_(0))).foreach {
case (orig: SparseVector, indexed: SparseVector) =>
assert(orig.indices.length == indexed.indices.length)
- case _ => throw new UnknownError("Unit test has a bug in it.") // should never happen
+ case _ =>
+ // should never happen
+ fail("Unit test has a bug in it.")
}
}
}
@@ -337,6 +337,7 @@ class VectorIndexerSuite extends MLTest with DefaultReadWriteTest with Logging {
}
private[feature] object VectorIndexerSuite {
- @BeanInfo
- case class FeatureData(@BeanProperty features: Vector)
+ case class FeatureData(features: Vector) {
+ def getFeatures: Vector = features
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala
index bdceba7887cac..8371c33a209dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/MatrixUDTSuite.scala
@@ -31,7 +31,7 @@ class MatrixUDTSuite extends SparkFunSuite {
val sm3 = dm3.toSparse
for (m <- Seq(dm1, dm2, dm3, sm1, sm2, sm3)) {
- val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.newInstance()
+ val udt = UDTRegistration.getUDTFor(m.getClass.getName).get.getConstructor().newInstance()
.asInstanceOf[MatrixUDT]
assert(m === udt.deserialize(udt.serialize(m)))
assert(udt.typeName == "matrix")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala
index 6ddb12cb76aac..67c64f762b25e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala
@@ -31,7 +31,7 @@ class VectorUDTSuite extends SparkFunSuite {
val sv2 = Vectors.sparse(2, Array(1), Array(2.0))
for (v <- Seq(dv1, dv2, sv1, sv2)) {
- val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.newInstance()
+ val udt = UDTRegistration.getUDTFor(v.getClass.getName).get.getConstructor().newInstance()
.asInstanceOf[VectorUDT]
assert(v === udt.deserialize(udt.serialize(v)))
assert(udt.typeName == "vector")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 9a59c41740daf..2fc9754ecfe1e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -601,7 +601,7 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging {
val df = maybeDf.get._2
val expected = estimator.fit(df)
- val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2)))
+ val actuals = dfs.map(t => (t, estimator.fit(t._2)))
actuals.foreach { case (_, actual) => check(expected, actual) }
actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 743dacf146fe7..5caa5117d5752 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -417,9 +417,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
case n: InternalNode => n.split match {
case s: CategoricalSplit =>
assert(s.leftCategories === Array(1.0))
- case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit")
+ case _ => fail("model.rootNode.split was not a CategoricalSplit")
}
- case _ => throw new AssertionError("model.rootNode was not an InternalNode")
+ case _ => fail("model.rootNode was not an InternalNode")
}
}
@@ -444,7 +444,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(n.leftChild.isInstanceOf[InternalNode])
assert(n.rightChild.isInstanceOf[InternalNode])
Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode])
- case _ => throw new AssertionError("rootNode was not an InternalNode")
+ case _ => fail("rootNode was not an InternalNode")
}
// Single group second level tree construction.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index b6894b30b0c2b..ae9794b87b08d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -112,7 +112,7 @@ private[ml] object TreeTests extends SparkFunSuite {
checkEqual(a.rootNode, b.rootNode)
} catch {
case ex: Exception =>
- throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+ fail("checkEqual failed since the two trees were not identical.\n" +
"TREE A:\n" + a.toDebugString + "\n" +
"TREE B:\n" + b.toDebugString + "\n", ex)
}
@@ -133,7 +133,7 @@ private[ml] object TreeTests extends SparkFunSuite {
checkEqual(aye.rightChild, bee.rightChild)
case (aye: LeafNode, bee: LeafNode) => // do nothing
case _ =>
- throw new AssertionError("Found mismatched nodes")
+ fail("Found mismatched nodes")
}
}
@@ -148,7 +148,7 @@ private[ml] object TreeTests extends SparkFunSuite {
}
assert(a.treeWeights === b.treeWeights)
} catch {
- case ex: Exception => throw new AssertionError(
+ case ex: Exception => fail(
"checkEqual failed since the two tree ensembles were not identical")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index e6ee7220d2279..a30428ec2d283 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -190,8 +190,8 @@ class CrossValidatorSuite
assert(lr.uid === lr2.uid)
assert(lr.getMaxIter === lr2.getMaxIter)
case other =>
- throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
- s" LogisticRegression but found ${other.getClass.getName}")
+ fail("Loaded CrossValidator expected estimator of type LogisticRegression" +
+ s" but found ${other.getClass.getName}")
}
ValidatorParamsSuiteHelpers
@@ -281,13 +281,13 @@ class CrossValidatorSuite
assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter
=== lr.getMaxIter)
case other =>
- throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
- s" LogisticRegression but found ${other.getClass.getName}")
+ fail("Loaded CrossValidator expected estimator of type LogisticRegression" +
+ s" but found ${other.getClass.getName}")
}
case other =>
- throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
- s" OneVsRest but found ${other.getClass.getName}")
+ fail("Loaded CrossValidator expected estimator of type OneVsRest but " +
+ s"found ${other.getClass.getName}")
}
ValidatorParamsSuiteHelpers
@@ -364,8 +364,8 @@ class CrossValidatorSuite
assert(lr.uid === lr2.uid)
assert(lr.getMaxIter === lr2.getMaxIter)
case other =>
- throw new AssertionError(s"Loaded internal CrossValidator expected to be" +
- s" LogisticRegression but found type ${other.getClass.getName}")
+ fail("Loaded internal CrossValidator expected to be LogisticRegression" +
+ s" but found type ${other.getClass.getName}")
}
assert(lrcv.uid === lrcv2.uid)
assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
@@ -373,12 +373,12 @@ class CrossValidatorSuite
ValidatorParamsSuiteHelpers
.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
case other =>
- throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" +
- " but found: " + other.map(_.getClass.getName).mkString(", "))
+ fail("Loaded Pipeline expected stages (HashingTF, CrossValidator) but found: " +
+ other.map(_.getClass.getName).mkString(", "))
}
case other =>
- throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
- s" CrossValidator but found ${other.getClass.getName}")
+ fail("Loaded CrossValidator expected estimator of type CrossValidator but found" +
+ s" ${other.getClass.getName}")
}
}
@@ -433,8 +433,8 @@ class CrossValidatorSuite
assert(lr.uid === lr2.uid)
assert(lr.getThreshold === lr2.getThreshold)
case other =>
- throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
- s" LogisticRegression but found ${other.getClass.getName}")
+ fail("Loaded CrossValidator expected estimator of type LogisticRegression" +
+ s" but found ${other.getClass.getName}")
}
ValidatorParamsSuiteHelpers
@@ -447,8 +447,8 @@ class CrossValidatorSuite
assert(lrModel.coefficients === lrModel2.coefficients)
assert(lrModel.intercept === lrModel2.intercept)
case other =>
- throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" +
- s" LogisticRegressionModel but found ${other.getClass.getName}")
+ fail("Loaded CrossValidator expected bestModel of type LogisticRegressionModel" +
+ s" but found ${other.getClass.getName}")
}
assert(cv.avgMetrics === cv2.avgMetrics)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index cd76acf9c67bc..289db336eca5d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -187,8 +187,8 @@ class TrainValidationSplitSuite
assert(lr.uid === lr2.uid)
assert(lr.getMaxIter === lr2.getMaxIter)
case other =>
- throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" +
- s" LogisticRegression but found ${other.getClass.getName}")
+ fail("Loaded TrainValidationSplit expected estimator of type LogisticRegression" +
+ s" but found ${other.getClass.getName}")
}
}
@@ -264,13 +264,13 @@ class TrainValidationSplitSuite
assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter
=== lr.getMaxIter)
case other =>
- throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" +
- s" LogisticRegression but found ${other.getClass.getName}")
+ fail(s"Loaded TrainValidationSplit expected estimator of type LogisticRegression" +
+ s" but found ${other.getClass.getName}")
}
case other =>
- throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" +
- s" OneVsRest but found ${other.getClass.getName}")
+ fail(s"Loaded TrainValidationSplit expected estimator of type OneVsRest" +
+ s" but found ${other.getClass.getName}")
}
ValidatorParamsSuiteHelpers
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala
index eae1f5adc8842..cea2f50d3470c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala
@@ -47,8 +47,7 @@ object ValidatorParamsSuiteHelpers extends Assertions {
val estimatorParamMap2 = Array(estimator2.extractParamMap())
compareParamMaps(estimatorParamMap, estimatorParamMap2)
case other =>
- throw new AssertionError(s"Expected parameter of type Params but" +
- s" found ${otherParam.getClass.getName}")
+ fail(s"Expected parameter of type Params but found ${otherParam.getClass.getName}")
}
case _ =>
assert(otherParam === v)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 5ec4c15387e94..8c7d583923b32 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -71,7 +71,7 @@ object NaiveBayesSuite {
counts.toArray.sortBy(_._1).map(_._2)
case _ =>
// This should never happen.
- throw new UnknownError(s"Invalid modelType: $modelType.")
+ throw new IllegalArgumentException(s"Invalid modelType: $modelType.")
}
LabeledPoint(y, Vectors.dense(xi))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 1b98250061c7a..d18cef7e264db 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -349,7 +349,7 @@ object KMeansSuite extends SparkFunSuite {
case (ca: DenseVector, cb: DenseVector) =>
assert(ca === cb)
case _ =>
- throw new AssertionError("checkEqual failed since the two clusters were not identical.\n")
+ fail("checkEqual failed since the two clusters were not identical.\n")
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
index 5394baab94bcf..8779de590a256 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -18,10 +18,14 @@
package org.apache.spark.mllib.evaluation
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.linalg.Matrices
+import org.apache.spark.ml.linalg.Matrices
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ private val delta = 1e-7
+
test("Multiclass evaluation metrics") {
/*
* Confusion matrix for 3-class classification with total 9 instances:
@@ -35,7 +39,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
(1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
val metrics = new MulticlassMetrics(predictionAndLabels)
- val delta = 0.0000001
val tpRate0 = 2.0 / (2 + 2)
val tpRate1 = 3.0 / (3 + 1)
val tpRate2 = 1.0 / (1 + 0)
@@ -55,41 +58,122 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)
- assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))
- assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta)
- assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta)
- assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta)
- assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
- assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
- assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)
- assert(math.abs(metrics.precision(0.0) - precision0) < delta)
- assert(math.abs(metrics.precision(1.0) - precision1) < delta)
- assert(math.abs(metrics.precision(2.0) - precision2) < delta)
- assert(math.abs(metrics.recall(0.0) - recall0) < delta)
- assert(math.abs(metrics.recall(1.0) - recall1) < delta)
- assert(math.abs(metrics.recall(2.0) - recall2) < delta)
- assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta)
- assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta)
- assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta)
- assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta)
- assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta)
- assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta)
+ assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta)
+ assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta)
+ assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta)
+ assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta)
+ assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta)
+ assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta)
+ assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta)
+ assert(metrics.precision(0.0) ~== precision0 relTol delta)
+ assert(metrics.precision(1.0) ~== precision1 relTol delta)
+ assert(metrics.precision(2.0) ~== precision2 relTol delta)
+ assert(metrics.recall(0.0) ~== recall0 relTol delta)
+ assert(metrics.recall(1.0) ~== recall1 relTol delta)
+ assert(metrics.recall(2.0) ~== recall2 relTol delta)
+ assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta)
+ assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta)
+ assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta)
+ assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta)
+ assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta)
+ assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta)
+
+ assert(metrics.accuracy ~==
+ (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) relTol delta)
+ assert(metrics.accuracy ~== metrics.weightedRecall relTol delta)
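+ // the weighted metrics use the label frequencies as class weights: 4/9, 4/9 and 1/9 of the instances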
+ val weight0 = 4.0 / 9
+ val weight1 = 4.0 / 9
+ val weight2 = 1.0 / 9
+ assert(metrics.weightedTruePositiveRate ~==
+ (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta)
+ assert(metrics.weightedFalsePositiveRate ~==
+ (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta)
+ assert(metrics.weightedPrecision ~==
+ (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta)
+ assert(metrics.weightedRecall ~==
+ (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta)
+ assert(metrics.weightedFMeasure ~==
+ (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta)
+ assert(metrics.weightedFMeasure(2.0) ~==
+ (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta)
+ assert(metrics.labels === labels)
+ }
+
+ test("Multiclass evaluation metrics with weights") {
+ /*
+ * Confusion matrix for 3-class classification with total 9 instances with 2 weights:
+ * |2 * w1|1 * w2 |1 * w1| true class0 (4 instances)
+ * |1 * w2|2 * w1 + 1 * w2|0 | true class1 (4 instances)
+ * |0 |0 |1 * w2| true class2 (1 instance)
+ */
+ val w1 = 2.2
+ val w2 = 1.5
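+ // tw is the total weight over all nine weighted instances (5 * w1 + 4 * w2)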
+ val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2
+ val confusionMatrix = Matrices.dense(3, 3,
+ Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 * w2))
+ val labels = Array(0.0, 1.0, 2.0)
+ val predictionAndLabelsWithWeights = sc.parallelize(
+ Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2),
+ (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2),
+ (2.0, 0.0, w1)), 2)
+ val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights)
+ val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1)
+ val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
+ val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0)
+ val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1))
+ val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2))
+ val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2))
+ val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2)
+ val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
+ val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2)
+ val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1)
+ val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
+ val recall2 = (1.0 * w2) / (1.0 * w2 + 0)
+ val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
+ val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
+ val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
+ val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0)
+ val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
+ val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)
+
+ assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta)
+ assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta)
+ assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta)
+ assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta)
+ assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta)
+ assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta)
+ assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta)
+ assert(metrics.precision(0.0) ~== precision0 relTol delta)
+ assert(metrics.precision(1.0) ~== precision1 relTol delta)
+ assert(metrics.precision(2.0) ~== precision2 relTol delta)
+ assert(metrics.recall(0.0) ~== recall0 relTol delta)
+ assert(metrics.recall(1.0) ~== recall1 relTol delta)
+ assert(metrics.recall(2.0) ~== recall2 relTol delta)
+ assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta)
+ assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta)
+ assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta)
+ assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta)
+ assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta)
+ assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta)
- assert(math.abs(metrics.accuracy -
- (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta)
- assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta)
- assert(math.abs(metrics.weightedTruePositiveRate -
- ((4.0 / 9) * tpRate0 + (4.0 / 9) * tpRate1 + (1.0 / 9) * tpRate2)) < delta)
- assert(math.abs(metrics.weightedFalsePositiveRate -
- ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta)
- assert(math.abs(metrics.weightedPrecision -
- ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta)
- assert(math.abs(metrics.weightedRecall -
- ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta)
- assert(math.abs(metrics.weightedFMeasure -
- ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta)
- assert(math.abs(metrics.weightedFMeasure(2.0) -
- ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta)
- assert(metrics.labels.sameElements(labels))
+ assert(metrics.accuracy ~==
+ (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta)
+ assert(metrics.accuracy ~== metrics.weightedRecall relTol delta)
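+ // per-class weights are the confusion-matrix row sums (total weight of each true class) divided by tw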
+ val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw
+ val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw
+ val weight2 = 1 * w2 / tw
+ assert(metrics.weightedTruePositiveRate ~==
+ (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta)
+ assert(metrics.weightedFalsePositiveRate ~==
+ (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta)
+ assert(metrics.weightedPrecision ~==
+ (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta)
+ assert(metrics.weightedRecall ~==
+ (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta)
+ assert(metrics.weightedFMeasure ~==
+ (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta)
+ assert(metrics.weightedFMeasure(2.0) ~==
+ (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta)
+ assert(metrics.labels === labels)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index d76edb940b2bd..2c3f84617cfa5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -511,10 +511,10 @@ class MatricesSuite extends SparkFunSuite {
mat.toString(0, 0)
mat.toString(Int.MinValue, Int.MinValue)
mat.toString(Int.MaxValue, Int.MaxValue)
- var lines = mat.toString(6, 50).lines.toArray
+ var lines = mat.toString(6, 50).split('\n')
assert(lines.size == 5 && lines.forall(_.size <= 50))
- lines = mat.toString(5, 100).lines.toArray
+ lines = mat.toString(5, 100).split('\n')
assert(lines.size == 5 && lines.forall(_.size <= 100))
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
index 669d44223d713..5b4a2607f0b25 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
@@ -17,10 +17,11 @@
package org.apache.spark.mllib.stat.distribution
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.serializer.KryoSerializer
class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext {
test("univariate") {
@@ -80,4 +81,23 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext
assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9)
}
+ test("Kryo class register") {
+ val conf = new SparkConf(false)
+ conf.set("spark.kryo.registrationRequired", "true")
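+ // registrationRequired makes the round trip below fail unless MultivariateGaussian
+ // (and the vector/matrix types it holds) are registered with Spark's KryoSerializer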
+
+ val ser = new KryoSerializer(conf).newInstance()
+
+ val mu = Vectors.dense(0.0, 0.0)
+ val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
+ val dist1 = new MultivariateGaussian(mu, sigma1)
+
+ val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
+ val dist2 = new MultivariateGaussian(mu, sigma2)
+
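+ // serialize and deserialize each distribution and check that its fields survive the round trip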
+ Seq(dist1, dist2).foreach { i =>
+ val i2 = ser.deserialize[MultivariateGaussian](ser.serialize(i))
+ assert(i.sigma === i2.sigma)
+ assert(i.mu === i2.mu)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index bc59f3f4125fb..34bc303ac6079 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -607,7 +607,7 @@ object DecisionTreeSuite extends SparkFunSuite {
checkEqual(a.topNode, b.topNode)
} catch {
case ex: Exception =>
- throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+ fail("checkEqual failed since the two trees were not identical.\n" +
"TREE A:\n" + a.toDebugString + "\n" +
"TREE B:\n" + b.toDebugString + "\n", ex)
}
@@ -628,20 +628,21 @@ object DecisionTreeSuite extends SparkFunSuite {
// TODO: Check other fields besides the information gain.
case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain)
case (None, None) =>
- case _ => throw new AssertionError(
- s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})")
+ case _ => fail(s"Only one instance has stats defined. (a.stats: ${a.stats}, " +
+ s"b.stats: ${b.stats})")
}
(a.leftNode, b.leftNode) match {
case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode)
case (None, None) =>
- case _ => throw new AssertionError("Only one instance has leftNode defined. " +
- s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})")
+ case _ =>
+ fail("Only one instance has leftNode defined. (a.leftNode: ${a.leftNode}," +
+ " b.leftNode: ${b.leftNode})")
}
(a.rightNode, b.rightNode) match {
case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode)
case (None, None) =>
- case _ => throw new AssertionError("Only one instance has rightNode defined. " +
- s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})")
+ case _ => fail("Only one instance has rightNode defined. (a.rightNode: ${a.rightNode}, " +
+ "b.rightNode: ${b.rightNode})")
}
}
}
diff --git a/pom.xml b/pom.xml
index 247b77c71928b..0654831b70582 100644
--- a/pom.xml
+++ b/pom.xml
@@ -25,7 +25,7 @@
<version>18</version>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-parent_2.11</artifactId>
+ <artifactId>spark-parent_2.12</artifactId>
<version>3.0.0-SNAPSHOT</version>
<packaging>pom</packaging>
<name>Spark Project Parent POM</name>
@@ -116,14 +116,14 @@
<java.version>1.8</java.version>
<maven.compiler.source>${java.version}</maven.compiler.source>
<maven.compiler.target>${java.version}</maven.compiler.target>
- <maven.version>3.5.4</maven.version>
+ <maven.version>3.6.0</maven.version>
<sbt.project.name>spark</sbt.project.name>
<slf4j.version>1.7.25</slf4j.version>
<log4j.version>1.2.17</log4j.version>
- <hadoop.version>2.7.3</hadoop.version>
+ <hadoop.version>2.7.4</hadoop.version>
<protobuf.version>2.5.0</protobuf.version>
<yarn.version>${hadoop.version}</yarn.version>
- <zookeeper.version>3.4.6</zookeeper.version>
+ <zookeeper.version>3.4.7</zookeeper.version>
<curator.version>2.7.1</curator.version>
<hive.group>org.spark-project.hive</hive.group>
@@ -164,8 +164,8 @@
<commons.math3.version>3.6.1</commons.math3.version>
<commons.collections.version>3.2.2</commons.collections.version>
- <scala.version>2.11.12</scala.version>
- <scala.binary.version>2.11</scala.binary.version>
+ <scala.version>2.12.7</scala.version>
+ <scala.binary.version>2.12</scala.binary.version>
<codehaus.jackson.version>1.9.13</codehaus.jackson.version>
<fasterxml.jackson.version>2.9.7</fasterxml.jackson.version>
<snappy.version>1.1.7.2</snappy.version>
@@ -176,8 +176,8 @@
2.6
- 3.8
- 1.8.1
+ 3.8.1
+ 1.181.63.2.101.1.1
@@ -188,7 +188,7 @@
3.5.23.0.20.9.3
- 4.7
+ 4.7.13.41.12.52.0
@@ -482,6 +482,11 @@
<artifactId>commons-lang3</artifactId>
<version>${commons-lang3.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-text</artifactId>
+ <version>1.6</version>
+ </dependency>
<dependency>
<groupId>commons-lang</groupId>
<artifactId>commons-lang</artifactId>
@@ -2264,7 +2269,7 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
- <version>3.0.0-M1</version>
+ <version>3.0.0-M2</version>
<executions>
<execution>
<id>enforce-versions</id>
@@ -2290,6 +2295,7 @@
-->
<exclude>org.jboss.netty</exclude>
<exclude>org.codehaus.groovy</exclude>
+ <exclude>*:*_2.11</exclude>
<exclude>*:*_2.10</exclude>
</excludes>
<searchTransitive>true</searchTransitive>
@@ -2307,8 +2313,7 @@
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
-
- <version>3.2.2</version>
+ <version>3.4.4</version>
<executions>
<execution>
<id>eclipse-add-source</id>
@@ -2328,6 +2333,13 @@
<goal>testCompile</goal>
</goals>
</execution>
+ <execution>
+ <id>attach-scaladocs</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>doc-jar</goal>
+ </goals>
+ </execution>
</executions>
<configuration>
<scalaVersion>${scala.version}</scalaVersion>
@@ -2357,7 +2369,7 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
- <version>3.7.0</version>
+ <version>3.8.0</version>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
@@ -2374,7 +2386,7 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
- <version>2.22.0</version>
+ <version>3.0.0-M1</version>
@@ -2428,7 +2440,7 @@
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
- <version>1.0</version>
+ <version>2.0.0</version>
<configuration>
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
@@ -2475,7 +2487,7 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
- <version>3.0.2</version>
+ <version>3.1.0</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
@@ -2502,7 +2514,7 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-clean-plugin</artifactId>
- <version>3.0.0</version>
+ <version>3.1.0</version>
@@ -2520,9 +2532,12 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
- <version>3.0.0-M1</version>
+ <version>3.0.1</version>
<configuration>
- <additionalparam>-Xdoclint:all -Xdoclint:-missing</additionalparam>
+ <additionalOptions>
+ <additionalOption>-Xdoclint:all</additionalOption>
+ <additionalOption>-Xdoclint:-missing</additionalOption>
+ </additionalOptions>
<tags>
<tag>
<name>example</name>
@@ -2573,22 +2588,34 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
- <version>3.1.0</version>
+ <version>3.2.1</version>
+ <dependencies>
+ <dependency>
+ <groupId>org.ow2.asm</groupId>
+ <artifactId>asm</artifactId>
+ <version>7.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.ow2.asm</groupId>
+ <artifactId>asm-commons</artifactId>
+ <version>7.0</version>
+ </dependency>
+ </dependencies>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-install-plugin</artifactId>
- <version>2.5.2</version>
+ <version>3.0.0-M1</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-deploy-plugin</artifactId>
- <version>2.8.2</version>
+ <version>3.0.0-M1</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
- <version>3.0.2</version>
+ <version>3.1.1</version>
<executions>
<execution>
<id>default-cli</id>
@@ -2629,7 +2656,7 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
- <versionRange>[2.6,)</versionRange>
+ <versionRange>3.1.0</versionRange>
<goals>
<goal>test-jar</goal>
</goals>
@@ -2786,12 +2813,17 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
- <version>2.17</version>
+ <version>3.0.0</version>
<configuration>
<failOnViolation>false</failOnViolation>
<includeTestSourceDirectory>true</includeTestSourceDirectory>
- <sourceDirectories>${basedir}/src/main/java,${basedir}/src/main/scala</sourceDirectories>
- <testSourceDirectories>${basedir}/src/test/java</testSourceDirectories>
+ <sourceDirectories>
+ <directory>${basedir}/src/main/java</directory>
+ <directory>${basedir}/src/main/scala</directory>
+ </sourceDirectories>
+ <testSourceDirectories>
+ <directory>${basedir}/src/test/java</directory>
+ </testSourceDirectories>
<configLocation>dev/checkstyle.xml</configLocation>
<outputFile>${basedir}/target/checkstyle-output.xml</outputFile>
<inputEncoding>${project.build.sourceEncoding}</inputEncoding>
@@ -2801,7 +2833,7 @@
<groupId>com.puppycrawl.tools</groupId>
<artifactId>checkstyle</artifactId>
- <version>8.2</version>
+ <version>8.14</version>
@@ -3027,14 +3059,14 @@
<profile>
- <id>scala-2.11</id>
+ <id>scala-2.12</id>
</profile>
<profile>
- <id>scala-2.12</id>
+ <id>scala-2.11</id>
<properties>
- <scala.version>2.12.7</scala.version>
- <scala.binary.version>2.12</scala.binary.version>
+ <scala.version>2.11.12</scala.version>
+ <scala.binary.version>2.11</scala.binary.version>
</properties>
@@ -3050,8 +3082,9 @@
-
- <exclude>*:*_2.11</exclude>
+
+ <exclude>org.jboss.netty</exclude>
+ <exclude>org.codehaus.groovy</exclude>
<exclude>*:*_2.10</exclude>
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index adde213e361f0..10c02103aeddb 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -88,9 +88,9 @@ object MimaBuild {
def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
val organization = "org.apache.spark"
- val previousSparkVersion = "2.2.0"
+ val previousSparkVersion = "2.4.0"
val project = projectRef.project
- val fullId = "spark-" + project + "_2.11"
+ val fullId = "spark-" + project + "_2.12"
mimaDefaultSettings ++
Seq(mimaPreviousArtifacts := Set(organization % fullId % previousSparkVersion),
mimaBinaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value))
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index f2e943cae6117..842730e7deb13 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,32 @@ object MimaExcludes {
// Exclude rules for 3.0.x
lazy val v30excludes = v24excludes ++ Seq(
+ // [SPARK-26124] Update plugins, including MiMa
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns.build"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.fullSchema"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.planInputPartitions"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.fullSchema"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.planInputPartitions"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters.build"),
+
+ // [SPARK-26090] Resolve most miscellaneous deprecation and build warnings for Spark 3
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.stat.test.BinarySampleBeanInfo"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LabeledPointBeanInfo"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"),
+
+ // [SPARK-25959] GBTClassifier picks wrong impurity stats on loading
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.tree.HasVarianceImpurity.org$apache$spark$ml$tree$HasVarianceImpurity$_setter_$impurity_="),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"),
+
+ // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleBytesWritten"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleWriteTime"),
@@ -50,10 +76,13 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.precision"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLWriter.context"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLReader.context"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.GeneralMLWriter.context"),
+
// [SPARK-25737] Remove JavaSparkContextVarargsWorkaround
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.api.java.JavaSparkContext"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.union"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.union"),
+
// [SPARK-16775] Remove deprecated accumulator v1 APIs
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulable"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam"),
@@ -73,14 +102,61 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulable"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.doubleAccumulator"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulator"),
+
// [SPARK-24109] Remove class SnappyOutputStreamWrapper
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"),
+
// [SPARK-19287] JavaPairRDD flatMapValues requires function returning Iterable, not Iterator
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"),
+
// [SPARK-25680] SQL execution listener shouldn't happen on execution thread
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this")
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this"),
+
+ // [SPARK-23781][CORE] Merge token renewer functionality into HadoopDelegationTokenManager
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.nextCredentialRenewalTime"),
+
+ // Data Source V2 API changes
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.ContinuousReadSupport"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.ReadSupport"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.WriteSupport"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.StreamWriteSupport"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.MicroBatchReadSupport"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder.build"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.InputPartition.createPartitionReader"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.estimateStatistics"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.fullSchema"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.planInputPartitions"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.outputPartitioning"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning.outputPartitioning"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.fullSchema"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ReadSupport.planInputPartitions"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.sources.v2.reader.ScanConfigBuilder.build"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.ContinuousInputPartition"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.InputPartitionReader"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.DataSourceWriter"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createWriter"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter"),
+
+ // [SPARK-26141] Enable custom metrics implementation in shuffle write
+ // Following are Java private classes
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this"),
+
+ // SafeLogging after MimaUpgrade
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.initializeLogIfNecessary$default$2"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.initializeLogIfNecessary$default$2"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.broadcast.Broadcast.initializeLogIfNecessary$default$2")
)
// Exclude rules for 2.4.x
@@ -214,7 +290,50 @@ object MimaExcludes {
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"),
// [SPARK-23042] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter")
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter"),
+
+ // [SPARK-21842][MESOS] Support Kerberos ticket renewal and creation in Mesos
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getDateOfNextUpdate"),
+
+ // [SPARK-23366] Improve hot reading path in ReadAheadInputStream
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.ReadAheadInputStream.this"),
+
+ // [SPARK-22941][CORE] Do not exit JVM when submit fails with in-process launcher.
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.addJarToClasspath"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.mergeFileLists"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment$default$2"),
+
+ // Data Source V2 API changes
+ // TODO: they are unstable APIs and should not be tracked by mima.
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.ReadSupportWithSchema"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.createDataReaderFactories"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.createBatchDataReaderFactories"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.planBatchInputPartitions"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanUnsafeRow"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader.createDataReaderFactories"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader.planInputPartitions"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownCatalystFilters"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataReader"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.getStatistics"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.estimateStatistics"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataReaderFactory"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createDataWriter"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createDataWriter"),
+
+ // Changes to HasRawPredictionCol.
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.rawPredictionCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.org$apache$spark$ml$param$shared$HasRawPredictionCol$_setter_$rawPredictionCol_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.getRawPredictionCol"),
+
+ // [SPARK-15526][ML][FOLLOWUP] Make JPMML provided scope to avoid including unshaded JARs
+ (problem: Problem) => problem match {
+ case MissingClassProblem(cls) =>
+ !cls.fullName.startsWith("org.spark_project.jpmml") &&
+ !cls.fullName.startsWith("org.spark_project.dmg.pmml")
+ case _ => true
+ }
)
// Exclude rules for 2.3.x
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 2f11c5deda217..639d642b78d2a 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -17,6 +17,7 @@
import java.io._
import java.nio.file.Files
+import java.util.Locale
import scala.io.Source
import scala.util.Properties
@@ -95,15 +96,15 @@ object SparkBuild extends PomBuild {
}
Option(System.getProperty("scala.version"))
- .filter(_.startsWith("2.12"))
+ .filter(_.startsWith("2.11"))
.foreach { versionString =>
- System.setProperty("scala-2.12", "true")
+ System.setProperty("scala-2.11", "true")
}
- if (System.getProperty("scala-2.12") == "") {
+ if (System.getProperty("scala-2.11") == "") {
// To activate scala-2.10 profile, replace empty property value to non-empty value
// in the same way as Maven which handles -Dname as -Dname=true before executes build process.
// see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082
- System.setProperty("scala-2.12", "true")
+ System.setProperty("scala-2.11", "true")
}
profiles
}
@@ -522,7 +523,8 @@ object KubernetesIntegrationTests {
s"-Dspark.kubernetes.test.unpackSparkDir=$sparkHome"
),
// Force packaging before building images, so that the latest code is tested.
- dockerBuild := dockerBuild.dependsOn(packageBin in Compile in assembly).value
+ dockerBuild := dockerBuild.dependsOn(packageBin in Compile in assembly)
+ .dependsOn(packageBin in Compile in examples).value
)
}
@@ -657,10 +659,13 @@ object Assembly {
},
jarName in (Test, assembly) := s"${moduleName.value}-test-${version.value}.jar",
mergeStrategy in assembly := {
- case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
- case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
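+ // Locale.ROOT keeps the case-insensitive matching independent of the JVM default locale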
+ case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf")
+ => MergeStrategy.discard
+ case m if m.toLowerCase(Locale.ROOT).matches("meta-inf.*\\.sf$")
+ => MergeStrategy.discard
case "log4j.properties" => MergeStrategy.discard
- case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines
+ case m if m.toLowerCase(Locale.ROOT).startsWith("meta-inf/services/")
+ => MergeStrategy.filterDistinctLines
case "reference.conf" => MergeStrategy.concat
case _ => MergeStrategy.first
}
@@ -852,10 +857,10 @@ object TestSettings {
import BuildCommons._
private val scalaBinaryVersion =
- if (System.getProperty("scala-2.12") == "true") {
- "2.12"
- } else {
+ if (System.getProperty("scala-2.11") == "true") {
"2.11"
+ } else {
+ "2.12"
}
lazy val settings = Seq (
// Fork new JVMs for tests and set Java options for those
diff --git a/project/plugins.sbt b/project/plugins.sbt
index da4fd524d6682..e2cfbe7bd6d4b 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -3,7 +3,7 @@ addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "0.8.5")
addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1")
// sbt-checkstyle-plugin uses an old version of checkstyle. Match it to Maven's.
-libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.2"
+libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "8.14"
// checkstyle uses guava 23.0.
libraryDependencies += "com.google.guava" % "guava" % "23.0"
@@ -11,13 +11,13 @@ libraryDependencies += "com.google.guava" % "guava" % "23.0"
// need to make changes to uptake sbt 1.0 support in "com.eed3si9n" % "sbt-assembly" % "1.14.5"
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")
-addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.3")
+addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.4")
-addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.0")
+addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.9.2")
addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0")
-addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.17")
+addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.3.0")
// sbt 1.0.0 support: https://github.com/AlpineNow/junit_xml_listener/issues/6
addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1")
@@ -30,12 +30,12 @@ addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2")
addSbtPlugin("io.spray" % "sbt-revolver" % "0.9.1")
-libraryDependencies += "org.ow2.asm" % "asm" % "5.1"
+libraryDependencies += "org.ow2.asm" % "asm" % "7.0"
-libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.1"
+libraryDependencies += "org.ow2.asm" % "asm-commons" % "7.0"
// sbt 1.0.0 support: https://github.com/ihji/sbt-antlr4/issues/14
-addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.11")
+addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.12")
// Spark uses a custom fork of the sbt-pom-reader plugin which contains a patch to fix issues
// related to test-jar dependencies (https://github.com/sbt/sbt-pom-reader/pull/14). The source for
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index b37129428f491..aaeeeb82d3d86 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -388,7 +388,7 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol
>>> len(centers)
2
>>> model.computeCost(df)
- 2.000...
+ 2.0
>>> transformed = model.transform(df).select("features", "prediction")
>>> rows = transformed.collect()
>>> rows[0].prediction == rows[1].prediction
@@ -403,7 +403,7 @@ class KMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPredictionCol
>>> summary.clusterSizes
[2, 2]
>>> summary.trainingCost
- 2.000...
+ 2.0
>>> kmeans_path = temp_path + "/kmeans"
>>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path)
@@ -595,7 +595,7 @@ class BisectingKMeans(JavaEstimator, HasDistanceMeasure, HasFeaturesCol, HasPred
>>> len(centers)
2
>>> model.computeCost(df)
- 2.000...
+ 2.0
>>> model.hasSummary
True
>>> summary = model.summary
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index eccb7acae5b98..3d23700242594 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -361,8 +361,9 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
"splits specified will be treated as errors.",
typeConverter=TypeConverters.toListFloat)
- handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
- "Options are 'skip' (filter out rows with invalid values), " +
+ handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries "
+ "containing NaN values. Values outside the splits will always be treated "
+ "as errors. Options are 'skip' (filter out rows with invalid values), " +
"'error' (throw an error), or 'keep' (keep invalid values in a special " +
"additional bucket).",
typeConverter=TypeConverters.toString)
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 886ad8409ca66..734763ebd3fa6 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -167,8 +167,8 @@ class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol,
independent group of mining tasks. The FP-Growth algorithm is described in
Han et al., Mining frequent patterns without candidate generation [HAN2000]_
- .. [LI2008] http://dx.doi.org/10.1145/1454008.1454027
- .. [HAN2000] http://dx.doi.org/10.1145/335191.335372
+ .. [LI2008] https://doi.org/10.1145/1454008.1454027
+ .. [HAN2000] https://doi.org/10.1145/335191.335372
.. note:: null values in the feature column are ignored during fit().
.. note:: Internally `transform` `collects` and `broadcasts` association rules.
@@ -254,7 +254,7 @@ class PrefixSpan(JavaParams):
A parallel PrefixSpan algorithm to mine frequent sequential patterns.
The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
Efficiently by Prefix-Projected Pattern Growth
- (see here).
+ (see here).
This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns`
method to run the PrefixSpan algorithm.
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index a8eae9bd268d3..520d7912c1a10 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -57,7 +57,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
For implicit preference data, the algorithm used is based on
`"Collaborative Filtering for Implicit Feedback Datasets",
- `_, adapted for the blocked
+ `_, adapted for the blocked
approach used here.
Essentially instead of finding the low-rank approximations to the
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
deleted file mode 100755
index 8c4f02dd724b4..0000000000000
--- a/python/pyspark/ml/tests.py
+++ /dev/null
@@ -1,2761 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Unit tests for MLlib Python DataFrame-based APIs.
-"""
-import sys
-
-import unishark
-
-if sys.version > '3':
- xrange = range
- basestring = str
-
-if sys.version_info[:2] <= (2, 6):
- try:
- import unittest2 as unittest
- except ImportError:
- sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
- sys.exit(1)
-else:
- import unittest
-
-from shutil import rmtree
-import tempfile
-import array as pyarray
-import numpy as np
-from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros
-import inspect
-import py4j
-
-from pyspark import keyword_only, SparkContext
-from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer
-from pyspark.ml.classification import *
-from pyspark.ml.clustering import *
-from pyspark.ml.common import _java2py, _py2java
-from pyspark.ml.evaluation import BinaryClassificationEvaluator, ClusteringEvaluator, \
- MulticlassClassificationEvaluator, RegressionEvaluator
-from pyspark.ml.feature import *
-from pyspark.ml.fpm import FPGrowth, FPGrowthModel
-from pyspark.ml.image import ImageSchema
-from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \
- SparseMatrix, SparseVector, Vector, VectorUDT, Vectors
-from pyspark.ml.param import Param, Params, TypeConverters
-from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed
-from pyspark.ml.recommendation import ALS
-from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \
- LinearRegression
-from pyspark.ml.stat import ChiSquareTest
-from pyspark.ml.tuning import *
-from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams, JavaWrapper
-from pyspark.serializers import PickleSerializer
-from pyspark.sql import DataFrame, Row, SparkSession, HiveContext
-from pyspark.sql.functions import rand
-from pyspark.sql.types import DoubleType, IntegerType
-from pyspark.storagelevel import *
-from pyspark.tests import QuietTest, ReusedPySparkTestCase as PySparkTestCase
-
-ser = PickleSerializer()
-
-
-class MLlibTestCase(unittest.TestCase):
- def setUp(self):
- self.sc = SparkContext('local[4]', "MLlib tests")
- self.spark = SparkSession(self.sc)
-
- def tearDown(self):
- self.spark.stop()
-
-
-class SparkSessionTestCase(PySparkTestCase):
- @classmethod
- def setUpClass(cls):
- PySparkTestCase.setUpClass()
- cls.spark = SparkSession(cls.sc)
-
- @classmethod
- def tearDownClass(cls):
- PySparkTestCase.tearDownClass()
- cls.spark.stop()
-
-
-class MockDataset(DataFrame):
-
- def __init__(self):
- self.index = 0
-
-
-class HasFake(Params):
-
- def __init__(self):
- super(HasFake, self).__init__()
- self.fake = Param(self, "fake", "fake param")
-
- def getFake(self):
- return self.getOrDefault(self.fake)
-
-
-class MockTransformer(Transformer, HasFake):
-
- def __init__(self):
- super(MockTransformer, self).__init__()
- self.dataset_index = None
-
- def _transform(self, dataset):
- self.dataset_index = dataset.index
- dataset.index += 1
- return dataset
-
-
-class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
-
- shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
- "data in a DataFrame",
- typeConverter=TypeConverters.toFloat)
-
- def __init__(self, shiftVal=1):
- super(MockUnaryTransformer, self).__init__()
- self._setDefault(shift=1)
- self._set(shift=shiftVal)
-
- def getShift(self):
- return self.getOrDefault(self.shift)
-
- def setShift(self, shift):
- self._set(shift=shift)
-
- def createTransformFunc(self):
- shiftVal = self.getShift()
- return lambda x: x + shiftVal
-
- def outputDataType(self):
- return DoubleType()
-
- def validateInputType(self, inputType):
- if inputType != DoubleType():
- raise TypeError("Bad input type: {}. ".format(inputType) +
- "Requires Double.")
-
-
-class MockEstimator(Estimator, HasFake):
-
- def __init__(self):
- super(MockEstimator, self).__init__()
- self.dataset_index = None
-
- def _fit(self, dataset):
- self.dataset_index = dataset.index
- model = MockModel()
- self._copyValues(model)
- return model
-
-
-class MockModel(MockTransformer, Model, HasFake):
- pass
-
-
-class JavaWrapperMemoryTests(SparkSessionTestCase):
-
- def test_java_object_gets_detached(self):
- df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
- (0.0, 2.0, Vectors.sparse(1, [], []))],
- ["label", "weight", "features"])
- lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight",
- fitIntercept=False)
-
- model = lr.fit(df)
- summary = model.summary
-
- self.assertIsInstance(model, JavaWrapper)
- self.assertIsInstance(summary, JavaWrapper)
- self.assertIsInstance(model, JavaParams)
- self.assertNotIsInstance(summary, JavaParams)
-
- error_no_object = 'Target Object ID does not exist for this gateway'
-
- self.assertIn("LinearRegression_", model._java_obj.toString())
- self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
-
- model.__del__()
-
- with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
- model._java_obj.toString()
- self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
-
- try:
- summary.__del__()
- except:
- pass
-
- with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
- model._java_obj.toString()
- with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
- summary._java_obj.toString()
-
-
-class ParamTypeConversionTests(PySparkTestCase):
- """
- Test that param type conversion happens.
- """
-
- def test_int(self):
- lr = LogisticRegression(maxIter=5.0)
- self.assertEqual(lr.getMaxIter(), 5)
- self.assertTrue(type(lr.getMaxIter()) == int)
- self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt"))
- self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1))
-
- def test_float(self):
- lr = LogisticRegression(tol=1)
- self.assertEqual(lr.getTol(), 1.0)
- self.assertTrue(type(lr.getTol()) == float)
- self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat"))
-
- def test_vector(self):
- ewp = ElementwiseProduct(scalingVec=[1, 3])
- self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0]))
- ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4]))
- self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4]))
- self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"]))
-
- def test_list(self):
- l = [0, 1]
- for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l),
- range(len(l)), l), pyarray.array('l', l), xrange(2), tuple(l)]:
- converted = TypeConverters.toList(lst_like)
- self.assertEqual(type(converted), list)
- self.assertListEqual(converted, l)
-
- def test_list_int(self):
- for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]),
- SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0),
- pyarray.array('d', [1.0, 2.0])]:
- vs = VectorSlicer(indices=indices)
- self.assertListEqual(vs.getIndices(), [1, 2])
- self.assertTrue(all([type(v) == int for v in vs.getIndices()]))
- self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"]))
-
- def test_list_float(self):
- b = Bucketizer(splits=[1, 4])
- self.assertEqual(b.getSplits(), [1.0, 4.0])
- self.assertTrue(all([type(v) == float for v in b.getSplits()]))
- self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0]))
-
- def test_list_string(self):
- for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]:
- idx_to_string = IndexToString(labels=labels)
- self.assertListEqual(idx_to_string.getLabels(), ['a', 'b'])
- self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2]))
-
- def test_string(self):
- lr = LogisticRegression()
- for col in ['features', u'features', np.str_('features')]:
- lr.setFeaturesCol(col)
- self.assertEqual(lr.getFeaturesCol(), 'features')
- self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3))
-
- def test_bool(self):
- self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1))
- self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false"))
-
-
-class PipelineTests(PySparkTestCase):
-
- def test_pipeline(self):
- dataset = MockDataset()
- estimator0 = MockEstimator()
- transformer1 = MockTransformer()
- estimator2 = MockEstimator()
- transformer3 = MockTransformer()
- pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3])
- pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
- model0, transformer1, model2, transformer3 = pipeline_model.stages
- self.assertEqual(0, model0.dataset_index)
- self.assertEqual(0, model0.getFake())
- self.assertEqual(1, transformer1.dataset_index)
- self.assertEqual(1, transformer1.getFake())
- self.assertEqual(2, dataset.index)
- self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.")
- self.assertIsNone(transformer3.dataset_index,
- "The last transformer shouldn't be called in fit.")
- dataset = pipeline_model.transform(dataset)
- self.assertEqual(2, model0.dataset_index)
- self.assertEqual(3, transformer1.dataset_index)
- self.assertEqual(4, model2.dataset_index)
- self.assertEqual(5, transformer3.dataset_index)
- self.assertEqual(6, dataset.index)
-
- def test_identity_pipeline(self):
- dataset = MockDataset()
-
- def doTransform(pipeline):
- pipeline_model = pipeline.fit(dataset)
- return pipeline_model.transform(dataset)
- # check that empty pipeline did not perform any transformation
- self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index)
- # check that failure to set stages param will raise KeyError for missing param
- self.assertRaises(KeyError, lambda: doTransform(Pipeline()))
-
-
-class TestParams(HasMaxIter, HasInputCol, HasSeed):
- """
- A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
- """
- @keyword_only
- def __init__(self, seed=None):
- super(TestParams, self).__init__()
- self._setDefault(maxIter=10)
- kwargs = self._input_kwargs
- self.setParams(**kwargs)
-
- @keyword_only
- def setParams(self, seed=None):
- """
- setParams(self, seed=None)
- Sets params for this test.
- """
- kwargs = self._input_kwargs
- return self._set(**kwargs)
-
-
-class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
- """
- A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
- """
- @keyword_only
- def __init__(self, seed=None):
- super(OtherTestParams, self).__init__()
- self._setDefault(maxIter=10)
- kwargs = self._input_kwargs
- self.setParams(**kwargs)
-
- @keyword_only
- def setParams(self, seed=None):
- """
- setParams(self, seed=None)
- Sets params for this test.
- """
- kwargs = self._input_kwargs
- return self._set(**kwargs)
-
-
-class HasThrowableProperty(Params):
-
- def __init__(self):
- super(HasThrowableProperty, self).__init__()
- self.p = Param(self, "none", "empty param")
-
- @property
- def test_property(self):
- raise RuntimeError("Test property to raise error when invoked")
-
-
-class ParamTests(SparkSessionTestCase):
-
- def test_copy_new_parent(self):
- testParams = TestParams()
- # Copying an instantiated param should fail
- with self.assertRaises(ValueError):
- testParams.maxIter._copy_new_parent(testParams)
- # Copying a dummy param should succeed
- TestParams.maxIter._copy_new_parent(testParams)
- maxIter = testParams.maxIter
- self.assertEqual(maxIter.name, "maxIter")
- self.assertEqual(maxIter.doc, "max number of iterations (>= 0).")
- self.assertTrue(maxIter.parent == testParams.uid)
-
- def test_param(self):
- testParams = TestParams()
- maxIter = testParams.maxIter
- self.assertEqual(maxIter.name, "maxIter")
- self.assertEqual(maxIter.doc, "max number of iterations (>= 0).")
- self.assertTrue(maxIter.parent == testParams.uid)
-
- def test_hasparam(self):
- testParams = TestParams()
- self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
- self.assertFalse(testParams.hasParam("notAParameter"))
- self.assertTrue(testParams.hasParam(u"maxIter"))
-
- def test_resolveparam(self):
- testParams = TestParams()
- self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter)
- self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter)
-
- self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter)
- if sys.version_info[0] >= 3:
- # In Python 3, it is allowed to get/set attributes with non-ascii characters.
- e_cls = AttributeError
- else:
- e_cls = UnicodeEncodeError
- self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아"))
-
- def test_params(self):
- testParams = TestParams()
- maxIter = testParams.maxIter
- inputCol = testParams.inputCol
- seed = testParams.seed
-
- params = testParams.params
- self.assertEqual(params, [inputCol, maxIter, seed])
-
- self.assertTrue(testParams.hasParam(maxIter.name))
- self.assertTrue(testParams.hasDefault(maxIter))
- self.assertFalse(testParams.isSet(maxIter))
- self.assertTrue(testParams.isDefined(maxIter))
- self.assertEqual(testParams.getMaxIter(), 10)
- testParams.setMaxIter(100)
- self.assertTrue(testParams.isSet(maxIter))
- self.assertEqual(testParams.getMaxIter(), 100)
-
- self.assertTrue(testParams.hasParam(inputCol.name))
- self.assertFalse(testParams.hasDefault(inputCol))
- self.assertFalse(testParams.isSet(inputCol))
- self.assertFalse(testParams.isDefined(inputCol))
- with self.assertRaises(KeyError):
- testParams.getInputCol()
-
- otherParam = Param(Params._dummy(), "otherParam", "Parameter used to test that " +
- "set raises an error for a non-member parameter.",
- typeConverter=TypeConverters.toString)
- with self.assertRaises(ValueError):
- testParams.set(otherParam, "value")
-
- # Since the default is normally random, set it to a known number for debug str
- testParams._setDefault(seed=41)
- testParams.setSeed(43)
-
- self.assertEqual(
- testParams.explainParams(),
- "\n".join(["inputCol: input column name. (undefined)",
- "maxIter: max number of iterations (>= 0). (default: 10, current: 100)",
- "seed: random seed. (default: 41, current: 43)"]))
-
- def test_kmeans_param(self):
- algo = KMeans()
- self.assertEqual(algo.getInitMode(), "k-means||")
- algo.setK(10)
- self.assertEqual(algo.getK(), 10)
- algo.setInitSteps(10)
- self.assertEqual(algo.getInitSteps(), 10)
- self.assertEqual(algo.getDistanceMeasure(), "euclidean")
- algo.setDistanceMeasure("cosine")
- self.assertEqual(algo.getDistanceMeasure(), "cosine")
-
- def test_hasseed(self):
- noSeedSpecd = TestParams()
- withSeedSpecd = TestParams(seed=42)
- other = OtherTestParams()
- # Check that we no longer use 42 as the magic number
- self.assertNotEqual(noSeedSpecd.getSeed(), 42)
- origSeed = noSeedSpecd.getSeed()
- # Check that we only compute the seed once
- self.assertEqual(noSeedSpecd.getSeed(), origSeed)
- # Check that a specified seed is honored
- self.assertEqual(withSeedSpecd.getSeed(), 42)
- # Check that a different class has a different seed
- self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed())
-
- def test_param_property_error(self):
- param_store = HasThrowableProperty()
- self.assertRaises(RuntimeError, lambda: param_store.test_property)
- params = param_store.params # should not invoke the property 'test_property'
- self.assertEqual(len(params), 1)
-
- def test_word2vec_param(self):
- model = Word2Vec().setWindowSize(6)
- # Check windowSize is set properly
- self.assertEqual(model.getWindowSize(), 6)
-
- def test_copy_param_extras(self):
- tp = TestParams(seed=42)
- extra = {tp.getParam(TestParams.inputCol.name): "copy_input"}
- tp_copy = tp.copy(extra=extra)
- self.assertEqual(tp.uid, tp_copy.uid)
- self.assertEqual(tp.params, tp_copy.params)
- for k, v in extra.items():
- self.assertTrue(tp_copy.isDefined(k))
- self.assertEqual(tp_copy.getOrDefault(k), v)
- copied_no_extra = {}
- for k, v in tp_copy._paramMap.items():
- if k not in extra:
- copied_no_extra[k] = v
- self.assertEqual(tp._paramMap, copied_no_extra)
- self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap)
-
- def test_logistic_regression_check_thresholds(self):
- self.assertIsInstance(
- LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
- LogisticRegression
- )
-
- self.assertRaisesRegexp(
- ValueError,
- "Logistic Regression getThreshold found inconsistent.*$",
- LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
- )
-
- def test_preserve_set_state(self):
- dataset = self.spark.createDataFrame([(0.5,)], ["data"])
- binarizer = Binarizer(inputCol="data")
- self.assertFalse(binarizer.isSet("threshold"))
- binarizer.transform(dataset)
- binarizer._transfer_params_from_java()
- self.assertFalse(binarizer.isSet("threshold"),
- "Params not explicitly set should remain unset after transform")
-
- def test_default_params_transferred(self):
- dataset = self.spark.createDataFrame([(0.5,)], ["data"])
- binarizer = Binarizer(inputCol="data")
- # intentionally change the pyspark default, but don't set it
- binarizer._defaultParamMap[binarizer.outputCol] = "my_default"
- result = binarizer.transform(dataset).select("my_default").collect()
- self.assertFalse(binarizer.isSet(binarizer.outputCol))
- self.assertEqual(result[0][0], 1.0)
-
- @staticmethod
- def check_params(test_self, py_stage, check_params_exist=True):
- """
- Checks common requirements for Params.params:
- - set of params exist in Java and Python and are ordered by names
- - param parent has the same UID as the object's UID
- - default param value from Java matches value in Python
- - optionally check if all params from Java also exist in Python
- """
- py_stage_str = "%s %s" % (type(py_stage), py_stage)
- if not hasattr(py_stage, "_to_java"):
- return
- java_stage = py_stage._to_java()
- if java_stage is None:
- return
- test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str)
- if check_params_exist:
- param_names = [p.name for p in py_stage.params]
- java_params = list(java_stage.params())
- java_param_names = [jp.name() for jp in java_params]
- test_self.assertEqual(
- param_names, sorted(java_param_names),
- "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s"
- % (py_stage_str, java_param_names, param_names))
- for p in py_stage.params:
- test_self.assertEqual(p.parent, py_stage.uid)
- java_param = java_stage.getParam(p.name)
- py_has_default = py_stage.hasDefault(p)
- java_has_default = java_stage.hasDefault(java_param)
- test_self.assertEqual(py_has_default, java_has_default,
- "Default value mismatch of param %s for Params %s"
- % (p.name, str(py_stage)))
- if py_has_default:
- if p.name == "seed":
- continue # Random seeds between Spark and PySpark are different
- java_default = _java2py(test_self.sc,
- java_stage.clear(java_param).getOrDefault(java_param))
- py_stage._clear(p)
- py_default = py_stage.getOrDefault(p)
- # equality test for NaN is always False
- if isinstance(java_default, float) and np.isnan(java_default):
- java_default = "NaN"
- py_default = "NaN" if np.isnan(py_default) else "not NaN"
- test_self.assertEqual(
- java_default, py_default,
- "Java default %s != python default %s of param %s for Params %s"
- % (str(java_default), str(py_default), p.name, str(py_stage)))
-
-
-class EvaluatorTests(SparkSessionTestCase):
-
- def test_java_params(self):
- """
- This tests a bug fixed by SPARK-18274 which causes multiple copies
- of a Params instance in Python to be linked to the same Java instance.
- """
- evaluator = RegressionEvaluator(metricName="r2")
- df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)])
- evaluator.evaluate(df)
- self.assertEqual(evaluator._java_obj.getMetricName(), "r2")
- evaluatorCopy = evaluator.copy({evaluator.metricName: "mae"})
- evaluator.evaluate(df)
- evaluatorCopy.evaluate(df)
- self.assertEqual(evaluator._java_obj.getMetricName(), "r2")
- self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae")
-
- def test_clustering_evaluator_with_cosine_distance(self):
- featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]),
- [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0),
- ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)])
- dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
- evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine")
- self.assertEqual(evaluator.getDistanceMeasure(), "cosine")
- self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5))
-
-
-class FeatureTests(SparkSessionTestCase):
-
- def test_binarizer(self):
- b0 = Binarizer()
- self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold])
-        self.assertTrue(all([not b0.isSet(p) for p in b0.params]))
- self.assertTrue(b0.hasDefault(b0.threshold))
- self.assertEqual(b0.getThreshold(), 0.0)
- b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0)
- self.assertTrue(all([b0.isSet(p) for p in b0.params]))
- self.assertEqual(b0.getThreshold(), 1.0)
- self.assertEqual(b0.getInputCol(), "input")
- self.assertEqual(b0.getOutputCol(), "output")
-
- b0c = b0.copy({b0.threshold: 2.0})
- self.assertEqual(b0c.uid, b0.uid)
- self.assertListEqual(b0c.params, b0.params)
- self.assertEqual(b0c.getThreshold(), 2.0)
-
- b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output")
- self.assertNotEqual(b1.uid, b0.uid)
- self.assertEqual(b1.getThreshold(), 2.0)
- self.assertEqual(b1.getInputCol(), "input")
- self.assertEqual(b1.getOutputCol(), "output")
-
- def test_idf(self):
- dataset = self.spark.createDataFrame([
- (DenseVector([1.0, 2.0]),),
- (DenseVector([0.0, 1.0]),),
- (DenseVector([3.0, 0.2]),)], ["tf"])
- idf0 = IDF(inputCol="tf")
- self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol])
- idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"})
- self.assertEqual(idf0m.uid, idf0.uid,
- "Model should inherit the UID from its parent estimator.")
- output = idf0m.transform(dataset)
- self.assertIsNotNone(output.head().idf)
-        # Test that parameters are transferred to the Python model
- ParamTests.check_params(self, idf0m)
-
- def test_ngram(self):
- dataset = self.spark.createDataFrame([
- Row(input=["a", "b", "c", "d", "e"])])
- ngram0 = NGram(n=4, inputCol="input", outputCol="output")
- self.assertEqual(ngram0.getN(), 4)
- self.assertEqual(ngram0.getInputCol(), "input")
- self.assertEqual(ngram0.getOutputCol(), "output")
- transformedDF = ngram0.transform(dataset)
- self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"])
-
- def test_stopwordsremover(self):
- dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
- stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
- # Default
- self.assertEqual(stopWordRemover.getInputCol(), "input")
- transformedDF = stopWordRemover.transform(dataset)
- self.assertEqual(transformedDF.head().output, ["panda"])
- self.assertEqual(type(stopWordRemover.getStopWords()), list)
- self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring))
- # Custom
- stopwords = ["panda"]
- stopWordRemover.setStopWords(stopwords)
- self.assertEqual(stopWordRemover.getInputCol(), "input")
- self.assertEqual(stopWordRemover.getStopWords(), stopwords)
- transformedDF = stopWordRemover.transform(dataset)
- self.assertEqual(transformedDF.head().output, ["a"])
- # with language selection
- stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
- dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
- stopWordRemover.setStopWords(stopwords)
- self.assertEqual(stopWordRemover.getStopWords(), stopwords)
- transformedDF = stopWordRemover.transform(dataset)
- self.assertEqual(transformedDF.head().output, [])
- # with locale
- stopwords = ["BELKİ"]
- dataset = self.spark.createDataFrame([Row(input=["belki"])])
- stopWordRemover.setStopWords(stopwords).setLocale("tr")
- self.assertEqual(stopWordRemover.getStopWords(), stopwords)
- transformedDF = stopWordRemover.transform(dataset)
- self.assertEqual(transformedDF.head().output, [])
-
- def test_count_vectorizer_with_binary(self):
- dataset = self.spark.createDataFrame([
- (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
- (1, "a a".split(' '), SparseVector(3, {0: 1.0}),),
- (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
- (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"])
- cv = CountVectorizer(binary=True, inputCol="words", outputCol="features")
- model = cv.fit(dataset)
-
- transformedList = model.transform(dataset).select("features", "expected").collect()
-
- for r in transformedList:
- feature, expected = r
- self.assertEqual(feature, expected)
-
- def test_count_vectorizer_with_maxDF(self):
- dataset = self.spark.createDataFrame([
- (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
- (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
- (2, "a b".split(' '), SparseVector(3, {0: 1.0}),),
- (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])
- cv = CountVectorizer(inputCol="words", outputCol="features")
- model1 = cv.setMaxDF(3).fit(dataset)
- self.assertEqual(model1.vocabulary, ['b', 'c', 'd'])
-
- transformedList1 = model1.transform(dataset).select("features", "expected").collect()
-
- for r in transformedList1:
- feature, expected = r
- self.assertEqual(feature, expected)
-
- model2 = cv.setMaxDF(0.75).fit(dataset)
- self.assertEqual(model2.vocabulary, ['b', 'c', 'd'])
-
- transformedList2 = model2.transform(dataset).select("features", "expected").collect()
-
- for r in transformedList2:
- feature, expected = r
- self.assertEqual(feature, expected)
-
- def test_count_vectorizer_from_vocab(self):
- model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words",
- outputCol="features", minTF=2)
- self.assertEqual(model.vocabulary, ["a", "b", "c"])
- self.assertEqual(model.getMinTF(), 2)
-
- dataset = self.spark.createDataFrame([
- (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),),
- (1, "a a".split(' '), SparseVector(3, {0: 2.0}),),
- (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])
-
- transformed_list = model.transform(dataset).select("features", "expected").collect()
-
- for r in transformed_list:
- feature, expected = r
- self.assertEqual(feature, expected)
-
- # Test an empty vocabulary
- with QuietTest(self.sc):
- with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
- CountVectorizerModel.from_vocabulary([], inputCol="words")
-
- # Test model with default settings can transform
- model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words")
- transformed_list = model_default.transform(dataset)\
- .select(model_default.getOrDefault(model_default.outputCol)).collect()
- self.assertEqual(len(transformed_list), 3)
-
- def test_rformula_force_index_label(self):
- df = self.spark.createDataFrame([
- (1.0, 1.0, "a"),
- (0.0, 2.0, "b"),
- (1.0, 0.0, "a")], ["y", "x", "s"])
- # Does not index label by default since it's numeric type.
- rf = RFormula(formula="y ~ x + s")
- model = rf.fit(df)
- transformedDF = model.transform(df)
- self.assertEqual(transformedDF.head().label, 1.0)
- # Force to index label.
- rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True)
- model2 = rf2.fit(df)
- transformedDF2 = model2.transform(df)
- self.assertEqual(transformedDF2.head().label, 0.0)
-
- def test_rformula_string_indexer_order_type(self):
- df = self.spark.createDataFrame([
- (1.0, 1.0, "a"),
- (0.0, 2.0, "b"),
- (1.0, 0.0, "a")], ["y", "x", "s"])
- rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
- self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc')
- transformedDF = rf.fit(df).transform(df)
- observed = transformedDF.select("features").collect()
- expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
- for i in range(0, len(expected)):
- self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
-
- def test_string_indexer_handle_invalid(self):
- df = self.spark.createDataFrame([
- (0, "a"),
- (1, "d"),
- (2, None)], ["id", "label"])
-
- si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep",
- stringOrderType="alphabetAsc")
- model1 = si1.fit(df)
- td1 = model1.transform(df)
- actual1 = td1.select("id", "indexed").collect()
- expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)]
- self.assertEqual(actual1, expected1)
-
- si2 = si1.setHandleInvalid("skip")
- model2 = si2.fit(df)
- td2 = model2.transform(df)
- actual2 = td2.select("id", "indexed").collect()
- expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
- self.assertEqual(actual2, expected2)
-
- def test_string_indexer_from_labels(self):
- model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label",
- outputCol="indexed", handleInvalid="keep")
- self.assertEqual(model.labels, ["a", "b", "c"])
-
- df1 = self.spark.createDataFrame([
- (0, "a"),
- (1, "c"),
- (2, None),
- (3, "b"),
- (4, "b")], ["id", "label"])
-
- result1 = model.transform(df1)
- actual1 = result1.select("id", "indexed").collect()
- expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0),
- Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)]
- self.assertEqual(actual1, expected1)
-
- model_empty_labels = StringIndexerModel.from_labels(
- [], inputCol="label", outputCol="indexed", handleInvalid="keep")
- actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect()
- expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0),
- Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)]
- self.assertEqual(actual2, expected2)
-
- # Test model with default settings can transform
- model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label")
- df2 = self.spark.createDataFrame([
- (0, "a"),
- (1, "c"),
- (2, "b"),
- (3, "b"),
- (4, "b")], ["id", "label"])
- transformed_list = model_default.transform(df2)\
- .select(model_default.getOrDefault(model_default.outputCol)).collect()
- self.assertEqual(len(transformed_list), 5)
-
- def test_vector_size_hint(self):
- df = self.spark.createDataFrame(
- [(0, Vectors.dense([0.0, 10.0, 0.5])),
- (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])),
- (2, Vectors.dense([2.0, 12.0]))],
- ["id", "vector"])
-
- sizeHint = VectorSizeHint(
- inputCol="vector",
- handleInvalid="skip")
- sizeHint.setSize(3)
- self.assertEqual(sizeHint.getSize(), 3)
-
- output = sizeHint.transform(df).head().vector
- expected = DenseVector([0.0, 10.0, 0.5])
- self.assertEqual(output, expected)
-
-
-class HasInducedError(Params):
-
- def __init__(self):
- super(HasInducedError, self).__init__()
- self.inducedError = Param(self, "inducedError",
- "Uniformly-distributed error added to feature")
-
- def getInducedError(self):
- return self.getOrDefault(self.inducedError)
-
-
-class InducedErrorModel(Model, HasInducedError):
-
- def __init__(self):
- super(InducedErrorModel, self).__init__()
-
- def _transform(self, dataset):
- return dataset.withColumn("prediction",
- dataset.feature + (rand(0) * self.getInducedError()))
-
-
-class InducedErrorEstimator(Estimator, HasInducedError):
-
- def __init__(self, inducedError=1.0):
- super(InducedErrorEstimator, self).__init__()
- self._set(inducedError=inducedError)
-
- def _fit(self, dataset):
- model = InducedErrorModel()
- self._copyValues(model)
- return model
-
-
-class CrossValidatorTests(SparkSessionTestCase):
-
- def test_copy(self):
- dataset = self.spark.createDataFrame([
- (10, 10.0),
- (50, 50.0),
- (100, 100.0),
- (500, 500.0)] * 10,
- ["feature", "label"])
-
- iee = InducedErrorEstimator()
- evaluator = RegressionEvaluator(metricName="rmse")
-
- grid = (ParamGridBuilder()
- .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
- .build())
- cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
- cvCopied = cv.copy()
- self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid)
-
- cvModel = cv.fit(dataset)
- cvModelCopied = cvModel.copy()
- for index in range(len(cvModel.avgMetrics)):
- self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index])
- < 0.0001)
-
- def test_fit_minimize_metric(self):
- dataset = self.spark.createDataFrame([
- (10, 10.0),
- (50, 50.0),
- (100, 100.0),
- (500, 500.0)] * 10,
- ["feature", "label"])
-
- iee = InducedErrorEstimator()
- evaluator = RegressionEvaluator(metricName="rmse")
-
- grid = (ParamGridBuilder()
- .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
- .build())
- cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
- cvModel = cv.fit(dataset)
- bestModel = cvModel.bestModel
- bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
-
- self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
- "Best model should have zero induced error")
- self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
-
- def test_fit_maximize_metric(self):
- dataset = self.spark.createDataFrame([
- (10, 10.0),
- (50, 50.0),
- (100, 100.0),
- (500, 500.0)] * 10,
- ["feature", "label"])
-
- iee = InducedErrorEstimator()
- evaluator = RegressionEvaluator(metricName="r2")
-
- grid = (ParamGridBuilder()
- .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
- .build())
- cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
- cvModel = cv.fit(dataset)
- bestModel = cvModel.bestModel
- bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
-
- self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
- "Best model should have zero induced error")
- self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
-
- def test_param_grid_type_coercion(self):
- lr = LogisticRegression(maxIter=10)
- paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build()
- for param in paramGrid:
- for v in param.values():
-                self.assertEqual(type(v), float)
-
- def test_save_load_trained_model(self):
- # This tests saving and loading the trained model only.
- # Save/load for CrossValidator will be added later: SPARK-13786
- temp_path = tempfile.mkdtemp()
- dataset = self.spark.createDataFrame(
- [(Vectors.dense([0.0]), 0.0),
- (Vectors.dense([0.4]), 1.0),
- (Vectors.dense([0.5]), 0.0),
- (Vectors.dense([0.6]), 1.0),
- (Vectors.dense([1.0]), 1.0)] * 10,
- ["features", "label"])
- lr = LogisticRegression()
- grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
- evaluator = BinaryClassificationEvaluator()
- cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- cvModel = cv.fit(dataset)
- lrModel = cvModel.bestModel
-
- cvModelPath = temp_path + "/cvModel"
- lrModel.save(cvModelPath)
- loadedLrModel = LogisticRegressionModel.load(cvModelPath)
- self.assertEqual(loadedLrModel.uid, lrModel.uid)
- self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
-
- def test_save_load_simple_estimator(self):
- temp_path = tempfile.mkdtemp()
- dataset = self.spark.createDataFrame(
- [(Vectors.dense([0.0]), 0.0),
- (Vectors.dense([0.4]), 1.0),
- (Vectors.dense([0.5]), 0.0),
- (Vectors.dense([0.6]), 1.0),
- (Vectors.dense([1.0]), 1.0)] * 10,
- ["features", "label"])
-
- lr = LogisticRegression()
- grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
- evaluator = BinaryClassificationEvaluator()
-
- # test save/load of CrossValidator
- cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- cvModel = cv.fit(dataset)
- cvPath = temp_path + "/cv"
- cv.save(cvPath)
- loadedCV = CrossValidator.load(cvPath)
- self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
- self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
- self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
-
- # test save/load of CrossValidatorModel
- cvModelPath = temp_path + "/cvModel"
- cvModel.save(cvModelPath)
- loadedModel = CrossValidatorModel.load(cvModelPath)
- self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
-
- def test_parallel_evaluation(self):
- dataset = self.spark.createDataFrame(
- [(Vectors.dense([0.0]), 0.0),
- (Vectors.dense([0.4]), 1.0),
- (Vectors.dense([0.5]), 0.0),
- (Vectors.dense([0.6]), 1.0),
- (Vectors.dense([1.0]), 1.0)] * 10,
- ["features", "label"])
-
- lr = LogisticRegression()
- grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build()
- evaluator = BinaryClassificationEvaluator()
-
-        # Fit with serial and with parallel evaluation and compare the resulting metrics
- cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- cv.setParallelism(1)
- cvSerialModel = cv.fit(dataset)
- cv.setParallelism(2)
- cvParallelModel = cv.fit(dataset)
- self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics)
-
- def test_expose_sub_models(self):
- temp_path = tempfile.mkdtemp()
- dataset = self.spark.createDataFrame(
- [(Vectors.dense([0.0]), 0.0),
- (Vectors.dense([0.4]), 1.0),
- (Vectors.dense([0.5]), 0.0),
- (Vectors.dense([0.6]), 1.0),
- (Vectors.dense([1.0]), 1.0)] * 10,
- ["features", "label"])
-
- lr = LogisticRegression()
- grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
- evaluator = BinaryClassificationEvaluator()
-
- numFolds = 3
- cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
- numFolds=numFolds, collectSubModels=True)
-
- def checkSubModels(subModels):
- self.assertEqual(len(subModels), numFolds)
- for i in range(numFolds):
- self.assertEqual(len(subModels[i]), len(grid))
-
- cvModel = cv.fit(dataset)
- checkSubModels(cvModel.subModels)
-
-        # Test that the default value of the "persistSubModels" option is "true"
- testSubPath = temp_path + "/testCrossValidatorSubModels"
- savingPathWithSubModels = testSubPath + "cvModel3"
- cvModel.save(savingPathWithSubModels)
- cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
- checkSubModels(cvModel3.subModels)
- cvModel4 = cvModel3.copy()
- checkSubModels(cvModel4.subModels)
-
- savingPathWithoutSubModels = testSubPath + "cvModel2"
- cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
- cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
- self.assertEqual(cvModel2.subModels, None)
-
- for i in range(numFolds):
- for j in range(len(grid)):
- self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid)
-
- def test_save_load_nested_estimator(self):
- temp_path = tempfile.mkdtemp()
- dataset = self.spark.createDataFrame(
- [(Vectors.dense([0.0]), 0.0),
- (Vectors.dense([0.4]), 1.0),
- (Vectors.dense([0.5]), 0.0),
- (Vectors.dense([0.6]), 1.0),
- (Vectors.dense([1.0]), 1.0)] * 10,
- ["features", "label"])
-
- ova = OneVsRest(classifier=LogisticRegression())
- lr1 = LogisticRegression().setMaxIter(100)
- lr2 = LogisticRegression().setMaxIter(150)
- grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
- evaluator = MulticlassClassificationEvaluator()
-
- # test save/load of CrossValidator
- cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
- cvModel = cv.fit(dataset)
- cvPath = temp_path + "/cv"
- cv.save(cvPath)
- loadedCV = CrossValidator.load(cvPath)
- self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
- self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
-
- originalParamMap = cv.getEstimatorParamMaps()
- loadedParamMap = loadedCV.getEstimatorParamMaps()
- for i, param in enumerate(loadedParamMap):
- for p in param:
- if p.name == "classifier":
- self.assertEqual(param[p].uid, originalParamMap[i][p].uid)
- else:
- self.assertEqual(param[p], originalParamMap[i][p])
-
- # test save/load of CrossValidatorModel
- cvModelPath = temp_path + "/cvModel"
- cvModel.save(cvModelPath)
- loadedModel = CrossValidatorModel.load(cvModelPath)
- self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
-
-
-class TrainValidationSplitTests(SparkSessionTestCase):
-
- def test_fit_minimize_metric(self):
- dataset = self.spark.createDataFrame([
- (10, 10.0),
- (50, 50.0),
- (100, 100.0),
- (500, 500.0)] * 10,
- ["feature", "label"])
-
- iee = InducedErrorEstimator()
- evaluator = RegressionEvaluator(metricName="rmse")
-
- grid = ParamGridBuilder() \
- .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \
- .build()
- tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
- tvsModel = tvs.fit(dataset)
- bestModel = tvsModel.bestModel
- bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
- validationMetrics = tvsModel.validationMetrics
-
- self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
- "Best model should have zero induced error")
- self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
- self.assertEqual(len(grid), len(validationMetrics),
-                         "validationMetrics should have the same size as the parameter grid")
- self.assertEqual(0.0, min(validationMetrics))
-
- def test_fit_maximize_metric(self):
- dataset = self.spark.createDataFrame([
- (10, 10.0),
- (50, 50.0),
- (100, 100.0),
- (500, 500.0)] * 10,
- ["feature", "label"])
-
- iee = InducedErrorEstimator()
- evaluator = RegressionEvaluator(metricName="r2")
-
- grid = ParamGridBuilder() \
- .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \
- .build()
- tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
- tvsModel = tvs.fit(dataset)
- bestModel = tvsModel.bestModel
- bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
- validationMetrics = tvsModel.validationMetrics
-
- self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
- "Best model should have zero induced error")
- self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
- self.assertEqual(len(grid), len(validationMetrics),
-                         "validationMetrics should have the same size as the parameter grid")
- self.assertEqual(1.0, max(validationMetrics))
-
- def test_save_load_trained_model(self):
- # This tests saving and loading the trained model only.
- # Save/load for TrainValidationSplit will be added later: SPARK-13786
- temp_path = tempfile.mkdtemp()
- dataset = self.spark.createDataFrame(
- [(Vectors.dense([0.0]), 0.0),
- (Vectors.dense([0.4]), 1.0),
- (Vectors.dense([0.5]), 0.0),
- (Vectors.dense([0.6]), 1.0),
- (Vectors.dense([1.0]), 1.0)] * 10,
- ["features", "label"])
- lr = LogisticRegression()
- grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
- evaluator = BinaryClassificationEvaluator()
- tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- tvsModel = tvs.fit(dataset)
- lrModel = tvsModel.bestModel
-
- tvsModelPath = temp_path + "/tvsModel"
- lrModel.save(tvsModelPath)
- loadedLrModel = LogisticRegressionModel.load(tvsModelPath)
- self.assertEqual(loadedLrModel.uid, lrModel.uid)
- self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
-
- def test_save_load_simple_estimator(self):
-        # Test save/load of TrainValidationSplit and TrainValidationSplitModel.
- temp_path = tempfile.mkdtemp()
- dataset = self.spark.createDataFrame(
- [(Vectors.dense([0.0]), 0.0),
- (Vectors.dense([0.4]), 1.0),
- (Vectors.dense([0.5]), 0.0),
- (Vectors.dense([0.6]), 1.0),
- (Vectors.dense([1.0]), 1.0)] * 10,
- ["features", "label"])
- lr = LogisticRegression()
- grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
- evaluator = BinaryClassificationEvaluator()
- tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- tvsModel = tvs.fit(dataset)
-
- tvsPath = temp_path + "/tvs"
- tvs.save(tvsPath)
- loadedTvs = TrainValidationSplit.load(tvsPath)
- self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
- self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
- self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
-
- tvsModelPath = temp_path + "/tvsModel"
- tvsModel.save(tvsModelPath)
- loadedModel = TrainValidationSplitModel.load(tvsModelPath)
- self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
-
- def test_parallel_evaluation(self):
- dataset = self.spark.createDataFrame(
- [(Vectors.dense([0.0]), 0.0),
- (Vectors.dense([0.4]), 1.0),
- (Vectors.dense([0.5]), 0.0),
- (Vectors.dense([0.6]), 1.0),
- (Vectors.dense([1.0]), 1.0)] * 10,
- ["features", "label"])
- lr = LogisticRegression()
- grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build()
- evaluator = BinaryClassificationEvaluator()
- tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- tvs.setParallelism(1)
- tvsSerialModel = tvs.fit(dataset)
- tvs.setParallelism(2)
- tvsParallelModel = tvs.fit(dataset)
- self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics)
-
- def test_expose_sub_models(self):
- temp_path = tempfile.mkdtemp()
- dataset = self.spark.createDataFrame(
- [(Vectors.dense([0.0]), 0.0),
- (Vectors.dense([0.4]), 1.0),
- (Vectors.dense([0.5]), 0.0),
- (Vectors.dense([0.6]), 1.0),
- (Vectors.dense([1.0]), 1.0)] * 10,
- ["features", "label"])
- lr = LogisticRegression()
- grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
- evaluator = BinaryClassificationEvaluator()
- tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
- collectSubModels=True)
- tvsModel = tvs.fit(dataset)
- self.assertEqual(len(tvsModel.subModels), len(grid))
-
-        # Test that the default value of the "persistSubModels" option is "true"
- testSubPath = temp_path + "/testTrainValidationSplitSubModels"
- savingPathWithSubModels = testSubPath + "cvModel3"
- tvsModel.save(savingPathWithSubModels)
- tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels)
- self.assertEqual(len(tvsModel3.subModels), len(grid))
- tvsModel4 = tvsModel3.copy()
- self.assertEqual(len(tvsModel4.subModels), len(grid))
-
- savingPathWithoutSubModels = testSubPath + "cvModel2"
- tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
- tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
- self.assertEqual(tvsModel2.subModels, None)
-
- for i in range(len(grid)):
- self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid)
-
- def test_save_load_nested_estimator(self):
-        # Test save/load of TrainValidationSplit with a nested estimator (OneVsRest).
- temp_path = tempfile.mkdtemp()
- dataset = self.spark.createDataFrame(
- [(Vectors.dense([0.0]), 0.0),
- (Vectors.dense([0.4]), 1.0),
- (Vectors.dense([0.5]), 0.0),
- (Vectors.dense([0.6]), 1.0),
- (Vectors.dense([1.0]), 1.0)] * 10,
- ["features", "label"])
- ova = OneVsRest(classifier=LogisticRegression())
- lr1 = LogisticRegression().setMaxIter(100)
- lr2 = LogisticRegression().setMaxIter(150)
- grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
- evaluator = MulticlassClassificationEvaluator()
-
- tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
- tvsModel = tvs.fit(dataset)
- tvsPath = temp_path + "/tvs"
- tvs.save(tvsPath)
- loadedTvs = TrainValidationSplit.load(tvsPath)
- self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
- self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
-
- originalParamMap = tvs.getEstimatorParamMaps()
- loadedParamMap = loadedTvs.getEstimatorParamMaps()
- for i, param in enumerate(loadedParamMap):
- for p in param:
- if p.name == "classifier":
- self.assertEqual(param[p].uid, originalParamMap[i][p].uid)
- else:
- self.assertEqual(param[p], originalParamMap[i][p])
-
- tvsModelPath = temp_path + "/tvsModel"
- tvsModel.save(tvsModelPath)
- loadedModel = TrainValidationSplitModel.load(tvsModelPath)
- self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
-
- def test_copy(self):
- dataset = self.spark.createDataFrame([
- (10, 10.0),
- (50, 50.0),
- (100, 100.0),
- (500, 500.0)] * 10,
- ["feature", "label"])
-
- iee = InducedErrorEstimator()
- evaluator = RegressionEvaluator(metricName="r2")
-
- grid = ParamGridBuilder() \
- .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \
- .build()
- tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
- tvsModel = tvs.fit(dataset)
- tvsCopied = tvs.copy()
- tvsModelCopied = tvsModel.copy()
-
- self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid,
-                         "Copied TrainValidationSplit should have the same estimator uid")
-
- self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid)
- self.assertEqual(len(tvsModel.validationMetrics),
- len(tvsModelCopied.validationMetrics),
-                         "Copied validationMetrics should have the same size as the original")
- for index in range(len(tvsModel.validationMetrics)):
- self.assertEqual(tvsModel.validationMetrics[index],
- tvsModelCopied.validationMetrics[index])
-
-
-class PersistenceTest(SparkSessionTestCase):
-
- def test_linear_regression(self):
- lr = LinearRegression(maxIter=1)
- path = tempfile.mkdtemp()
- lr_path = path + "/lr"
- lr.save(lr_path)
- lr2 = LinearRegression.load(lr_path)
- self.assertEqual(lr.uid, lr2.uid)
- self.assertEqual(type(lr.uid), type(lr2.uid))
- self.assertEqual(lr2.uid, lr2.maxIter.parent,
- "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
- % (lr2.uid, lr2.maxIter.parent))
- self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
- "Loaded LinearRegression instance default params did not match " +
- "original defaults")
- try:
- rmtree(path)
- except OSError:
- pass
-
- def test_linear_regression_pmml_basic(self):
-        # Most of the validation is done on the Scala side; here we just check
-        # that the output is text rather than Parquet (i.e. that the format flag
-        # was respected).
- df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
- (0.0, 2.0, Vectors.sparse(1, [], []))],
- ["label", "weight", "features"])
- lr = LinearRegression(maxIter=1)
- model = lr.fit(df)
- path = tempfile.mkdtemp()
- lr_path = path + "/lr-pmml"
- model.write().format("pmml").save(lr_path)
- pmml_text_list = self.sc.textFile(lr_path).collect()
- pmml_text = "\n".join(pmml_text_list)
- self.assertIn("Apache Spark", pmml_text)
- self.assertIn("PMML", pmml_text)
-
- def test_logistic_regression(self):
- lr = LogisticRegression(maxIter=1)
- path = tempfile.mkdtemp()
- lr_path = path + "/logreg"
- lr.save(lr_path)
- lr2 = LogisticRegression.load(lr_path)
- self.assertEqual(lr2.uid, lr2.maxIter.parent,
- "Loaded LogisticRegression instance uid (%s) "
- "did not match Param's uid (%s)"
- % (lr2.uid, lr2.maxIter.parent))
- self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
- "Loaded LogisticRegression instance default params did not match " +
- "original defaults")
- try:
- rmtree(path)
- except OSError:
- pass
-
- def _compare_params(self, m1, m2, param):
- """
- Compare 2 ML Params instances for the given param, and assert both have the same param value
- and parent. The param must be a parameter of m1.
- """
-        # Prevent a KeyError in case a param is in neither the paramMap nor the defaultParamMap.
- if m1.isDefined(param):
- paramValue1 = m1.getOrDefault(param)
- paramValue2 = m2.getOrDefault(m2.getParam(param.name))
- if isinstance(paramValue1, Params):
- self._compare_pipelines(paramValue1, paramValue2)
- else:
-                self.assertEqual(paramValue1, paramValue2)  # for params of non-Params types
- # Assert parents are equal
- self.assertEqual(param.parent, m2.getParam(param.name).parent)
- else:
-            # If the param is not defined in m1, m2 should not define it either. See SPARK-14931.
- self.assertFalse(m2.isDefined(m2.getParam(param.name)))
-
- def _compare_pipelines(self, m1, m2):
- """
- Compare 2 ML types, asserting that they are equivalent.
- This currently supports:
- - basic types
- - Pipeline, PipelineModel
- - OneVsRest, OneVsRestModel
- This checks:
- - uid
- - type
- - Param values and parents
- """
- self.assertEqual(m1.uid, m2.uid)
- self.assertEqual(type(m1), type(m2))
- if isinstance(m1, JavaParams) or isinstance(m1, Transformer):
- self.assertEqual(len(m1.params), len(m2.params))
- for p in m1.params:
- self._compare_params(m1, m2, p)
- elif isinstance(m1, Pipeline):
- self.assertEqual(len(m1.getStages()), len(m2.getStages()))
- for s1, s2 in zip(m1.getStages(), m2.getStages()):
- self._compare_pipelines(s1, s2)
- elif isinstance(m1, PipelineModel):
- self.assertEqual(len(m1.stages), len(m2.stages))
- for s1, s2 in zip(m1.stages, m2.stages):
- self._compare_pipelines(s1, s2)
- elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel):
- for p in m1.params:
- self._compare_params(m1, m2, p)
- if isinstance(m1, OneVsRestModel):
- self.assertEqual(len(m1.models), len(m2.models))
- for x, y in zip(m1.models, m2.models):
- self._compare_pipelines(x, y)
- else:
- raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
-
- def test_pipeline_persistence(self):
- """
- Pipeline[HashingTF, PCA]
- """
- temp_path = tempfile.mkdtemp()
-
- try:
- df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
- tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
- pca = PCA(k=2, inputCol="features", outputCol="pca_features")
- pl = Pipeline(stages=[tf, pca])
- model = pl.fit(df)
-
- pipeline_path = temp_path + "/pipeline"
- pl.save(pipeline_path)
- loaded_pipeline = Pipeline.load(pipeline_path)
- self._compare_pipelines(pl, loaded_pipeline)
-
- model_path = temp_path + "/pipeline-model"
- model.save(model_path)
- loaded_model = PipelineModel.load(model_path)
- self._compare_pipelines(model, loaded_model)
- finally:
- try:
- rmtree(temp_path)
- except OSError:
- pass
-
- def test_nested_pipeline_persistence(self):
- """
- Pipeline[HashingTF, Pipeline[PCA]]
- """
- temp_path = tempfile.mkdtemp()
-
- try:
- df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
- tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
- pca = PCA(k=2, inputCol="features", outputCol="pca_features")
- p0 = Pipeline(stages=[pca])
- pl = Pipeline(stages=[tf, p0])
- model = pl.fit(df)
-
- pipeline_path = temp_path + "/pipeline"
- pl.save(pipeline_path)
- loaded_pipeline = Pipeline.load(pipeline_path)
- self._compare_pipelines(pl, loaded_pipeline)
-
- model_path = temp_path + "/pipeline-model"
- model.save(model_path)
- loaded_model = PipelineModel.load(model_path)
- self._compare_pipelines(model, loaded_model)
- finally:
- try:
- rmtree(temp_path)
- except OSError:
- pass
-
- def test_python_transformer_pipeline_persistence(self):
- """
- Pipeline[MockUnaryTransformer, Binarizer]
- """
- temp_path = tempfile.mkdtemp()
-
- try:
- df = self.spark.range(0, 10).toDF('input')
- tf = MockUnaryTransformer(shiftVal=2)\
- .setInputCol("input").setOutputCol("shiftedInput")
- tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
- pl = Pipeline(stages=[tf, tf2])
- model = pl.fit(df)
-
- pipeline_path = temp_path + "/pipeline"
- pl.save(pipeline_path)
- loaded_pipeline = Pipeline.load(pipeline_path)
- self._compare_pipelines(pl, loaded_pipeline)
-
- model_path = temp_path + "/pipeline-model"
- model.save(model_path)
- loaded_model = PipelineModel.load(model_path)
- self._compare_pipelines(model, loaded_model)
- finally:
- try:
- rmtree(temp_path)
- except OSError:
- pass
-
- def test_onevsrest(self):
- temp_path = tempfile.mkdtemp()
- df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
- (1.0, Vectors.sparse(2, [], [])),
- (2.0, Vectors.dense(0.5, 0.5))] * 10,
- ["label", "features"])
- lr = LogisticRegression(maxIter=5, regParam=0.01)
- ovr = OneVsRest(classifier=lr)
- model = ovr.fit(df)
- ovrPath = temp_path + "/ovr"
- ovr.save(ovrPath)
- loadedOvr = OneVsRest.load(ovrPath)
- self._compare_pipelines(ovr, loadedOvr)
- modelPath = temp_path + "/ovrModel"
- model.save(modelPath)
- loadedModel = OneVsRestModel.load(modelPath)
- self._compare_pipelines(model, loadedModel)
-
- def test_decisiontree_classifier(self):
- dt = DecisionTreeClassifier(maxDepth=1)
- path = tempfile.mkdtemp()
- dtc_path = path + "/dtc"
- dt.save(dtc_path)
- dt2 = DecisionTreeClassifier.load(dtc_path)
- self.assertEqual(dt2.uid, dt2.maxDepth.parent,
- "Loaded DecisionTreeClassifier instance uid (%s) "
- "did not match Param's uid (%s)"
- % (dt2.uid, dt2.maxDepth.parent))
- self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
- "Loaded DecisionTreeClassifier instance default params did not match " +
- "original defaults")
- try:
- rmtree(path)
- except OSError:
- pass
-
- def test_decisiontree_regressor(self):
- dt = DecisionTreeRegressor(maxDepth=1)
- path = tempfile.mkdtemp()
- dtr_path = path + "/dtr"
- dt.save(dtr_path)
-        dt2 = DecisionTreeRegressor.load(dtr_path)
- self.assertEqual(dt2.uid, dt2.maxDepth.parent,
- "Loaded DecisionTreeRegressor instance uid (%s) "
- "did not match Param's uid (%s)"
- % (dt2.uid, dt2.maxDepth.parent))
- self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
- "Loaded DecisionTreeRegressor instance default params did not match " +
- "original defaults")
- try:
- rmtree(path)
- except OSError:
- pass
-
- def test_default_read_write(self):
- temp_path = tempfile.mkdtemp()
-
- lr = LogisticRegression()
- lr.setMaxIter(50)
- lr.setThreshold(.75)
- writer = DefaultParamsWriter(lr)
-
- savePath = temp_path + "/lr"
- writer.save(savePath)
-
- reader = DefaultParamsReadable.read()
- lr2 = reader.load(savePath)
-
- self.assertEqual(lr.uid, lr2.uid)
- self.assertEqual(lr.extractParamMap(), lr2.extractParamMap())
-
- # test overwrite
- lr.setThreshold(.8)
- writer.overwrite().save(savePath)
-
- reader = DefaultParamsReadable.read()
- lr3 = reader.load(savePath)
-
- self.assertEqual(lr.uid, lr3.uid)
- self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
-
- def test_default_read_write_default_params(self):
- lr = LogisticRegression()
- self.assertFalse(lr.isSet(lr.getParam("threshold")))
-
- lr.setMaxIter(50)
- lr.setThreshold(.75)
-
- # `threshold` is set by user, default param `predictionCol` is not set by user.
- self.assertTrue(lr.isSet(lr.getParam("threshold")))
- self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
- self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
-
- writer = DefaultParamsWriter(lr)
- metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
- self.assertTrue("defaultParamMap" in metadata)
-
- reader = DefaultParamsReadable.read()
- metadataStr = json.dumps(metadata, separators=[',', ':'])
-        loadedMetadata = reader._parseMetaData(metadataStr)
- reader.getAndSetParams(lr, loadedMetadata)
-
- self.assertTrue(lr.isSet(lr.getParam("threshold")))
- self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
- self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
-
- # manually create metadata without `defaultParamMap` section.
- del metadata['defaultParamMap']
- metadataStr = json.dumps(metadata, separators=[',', ':'])
-        loadedMetadata = reader._parseMetaData(metadataStr)
- with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
- reader.getAndSetParams(lr, loadedMetadata)
-
- # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
- metadata['sparkVersion'] = '2.3.0'
- metadataStr = json.dumps(metadata, separators=[',', ':'])
-        loadedMetadata = reader._parseMetaData(metadataStr)
- reader.getAndSetParams(lr, loadedMetadata)
-
-
-class LDATest(SparkSessionTestCase):
-
- def _compare(self, m1, m2):
- """
- Temp method for comparing instances.
- TODO: Replace with generic implementation once SPARK-14706 is merged.
- """
- self.assertEqual(m1.uid, m2.uid)
- self.assertEqual(type(m1), type(m2))
- self.assertEqual(len(m1.params), len(m2.params))
- for p in m1.params:
- if m1.isDefined(p):
- self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
- self.assertEqual(p.parent, m2.getParam(p.name).parent)
- if isinstance(m1, LDAModel):
- self.assertEqual(m1.vocabSize(), m2.vocabSize())
- self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix())
-
- def test_persistence(self):
- # Test save/load for LDA, LocalLDAModel, DistributedLDAModel.
- df = self.spark.createDataFrame([
- [1, Vectors.dense([0.0, 1.0])],
- [2, Vectors.sparse(2, {0: 1.0})],
- ], ["id", "features"])
- # Fit model
- lda = LDA(k=2, seed=1, optimizer="em")
- distributedModel = lda.fit(df)
- self.assertTrue(distributedModel.isDistributed())
- localModel = distributedModel.toLocal()
- self.assertFalse(localModel.isDistributed())
- # Define paths
- path = tempfile.mkdtemp()
- lda_path = path + "/lda"
- dist_model_path = path + "/distLDAModel"
- local_model_path = path + "/localLDAModel"
- # Test LDA
- lda.save(lda_path)
- lda2 = LDA.load(lda_path)
- self._compare(lda, lda2)
- # Test DistributedLDAModel
- distributedModel.save(dist_model_path)
- distributedModel2 = DistributedLDAModel.load(dist_model_path)
- self._compare(distributedModel, distributedModel2)
- # Test LocalLDAModel
- localModel.save(local_model_path)
- localModel2 = LocalLDAModel.load(local_model_path)
- self._compare(localModel, localModel2)
- # Clean up
- try:
- rmtree(path)
- except OSError:
- pass
-
-
-class TrainingSummaryTest(SparkSessionTestCase):
-
- def test_linear_regression_summary(self):
- df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
- (0.0, 2.0, Vectors.sparse(1, [], []))],
- ["label", "weight", "features"])
- lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight",
- fitIntercept=False)
- model = lr.fit(df)
- self.assertTrue(model.hasSummary)
- s = model.summary
- # test that api is callable and returns expected types
- self.assertGreater(s.totalIterations, 0)
- self.assertTrue(isinstance(s.predictions, DataFrame))
- self.assertEqual(s.predictionCol, "prediction")
- self.assertEqual(s.labelCol, "label")
- self.assertEqual(s.featuresCol, "features")
- objHist = s.objectiveHistory
- self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
- self.assertAlmostEqual(s.explainedVariance, 0.25, 2)
- self.assertAlmostEqual(s.meanAbsoluteError, 0.0)
- self.assertAlmostEqual(s.meanSquaredError, 0.0)
- self.assertAlmostEqual(s.rootMeanSquaredError, 0.0)
- self.assertAlmostEqual(s.r2, 1.0, 2)
- self.assertAlmostEqual(s.r2adj, 1.0, 2)
- self.assertTrue(isinstance(s.residuals, DataFrame))
- self.assertEqual(s.numInstances, 2)
- self.assertEqual(s.degreesOfFreedom, 1)
- devResiduals = s.devianceResiduals
- self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float))
- coefStdErr = s.coefficientStandardErrors
- self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float))
- tValues = s.tValues
- self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float))
- pValues = s.pValues
- self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float))
- # test evaluation (with training dataset) produces a summary with same values
- # one check is enough to verify a summary is returned
- # The child class LinearRegressionTrainingSummary runs full test
- sameSummary = model.evaluate(df)
- self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance)
-
- def test_glr_summary(self):
- from pyspark.ml.linalg import Vectors
- df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
- (0.0, 2.0, Vectors.sparse(1, [], []))],
- ["label", "weight", "features"])
- glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight",
- fitIntercept=False)
- model = glr.fit(df)
- self.assertTrue(model.hasSummary)
- s = model.summary
- # test that api is callable and returns expected types
- self.assertEqual(s.numIterations, 1) # this should default to a single iteration of WLS
- self.assertTrue(isinstance(s.predictions, DataFrame))
- self.assertEqual(s.predictionCol, "prediction")
- self.assertEqual(s.numInstances, 2)
- self.assertTrue(isinstance(s.residuals(), DataFrame))
- self.assertTrue(isinstance(s.residuals("pearson"), DataFrame))
- coefStdErr = s.coefficientStandardErrors
- self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float))
- tValues = s.tValues
- self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float))
- pValues = s.pValues
- self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float))
- self.assertEqual(s.degreesOfFreedom, 1)
- self.assertEqual(s.residualDegreeOfFreedom, 1)
- self.assertEqual(s.residualDegreeOfFreedomNull, 2)
- self.assertEqual(s.rank, 1)
- self.assertTrue(isinstance(s.solver, basestring))
- self.assertTrue(isinstance(s.aic, float))
- self.assertTrue(isinstance(s.deviance, float))
- self.assertTrue(isinstance(s.nullDeviance, float))
- self.assertTrue(isinstance(s.dispersion, float))
- # test evaluation (with training dataset) produces a summary with same values
- # one check is enough to verify a summary is returned
- # The child class GeneralizedLinearRegressionTrainingSummary runs full test
- sameSummary = model.evaluate(df)
- self.assertAlmostEqual(sameSummary.deviance, s.deviance)
-
- def test_binary_logistic_regression_summary(self):
- df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
- (0.0, 2.0, Vectors.sparse(1, [], []))],
- ["label", "weight", "features"])
- lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
- model = lr.fit(df)
- self.assertTrue(model.hasSummary)
- s = model.summary
- # test that api is callable and returns expected types
- self.assertTrue(isinstance(s.predictions, DataFrame))
- self.assertEqual(s.probabilityCol, "probability")
- self.assertEqual(s.labelCol, "label")
- self.assertEqual(s.featuresCol, "features")
- self.assertEqual(s.predictionCol, "prediction")
- objHist = s.objectiveHistory
- self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
- self.assertGreater(s.totalIterations, 0)
- self.assertTrue(isinstance(s.labels, list))
- self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
- self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
- self.assertTrue(isinstance(s.precisionByLabel, list))
- self.assertTrue(isinstance(s.recallByLabel, list))
- self.assertTrue(isinstance(s.fMeasureByLabel(), list))
- self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
- self.assertTrue(isinstance(s.roc, DataFrame))
- self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
- self.assertTrue(isinstance(s.pr, DataFrame))
- self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
- self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
- self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
- self.assertAlmostEqual(s.accuracy, 1.0, 2)
- self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
- self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
- self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
- self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
- self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
- self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
- # test evaluation (with training dataset) produces a summary with same values
- # one check is enough to verify a summary is returned, Scala version runs full test
- sameSummary = model.evaluate(df)
- self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
-
- def test_multiclass_logistic_regression_summary(self):
- df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
- (0.0, 2.0, Vectors.sparse(1, [], [])),
- (2.0, 2.0, Vectors.dense(2.0)),
- (2.0, 2.0, Vectors.dense(1.9))],
- ["label", "weight", "features"])
- lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
- model = lr.fit(df)
- self.assertTrue(model.hasSummary)
- s = model.summary
- # test that api is callable and returns expected types
- self.assertTrue(isinstance(s.predictions, DataFrame))
- self.assertEqual(s.probabilityCol, "probability")
- self.assertEqual(s.labelCol, "label")
- self.assertEqual(s.featuresCol, "features")
- self.assertEqual(s.predictionCol, "prediction")
- objHist = s.objectiveHistory
- self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
- self.assertGreater(s.totalIterations, 0)
- self.assertTrue(isinstance(s.labels, list))
- self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
- self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
- self.assertTrue(isinstance(s.precisionByLabel, list))
- self.assertTrue(isinstance(s.recallByLabel, list))
- self.assertTrue(isinstance(s.fMeasureByLabel(), list))
- self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
- self.assertAlmostEqual(s.accuracy, 0.75, 2)
- self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
- self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
- self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
- self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
- self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2)
- self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2)
- # test evaluation (with training dataset) produces a summary with same values
- # one check is enough to verify a summary is returned, Scala version runs full test
- sameSummary = model.evaluate(df)
- self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
-
- def test_gaussian_mixture_summary(self):
- data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
- (Vectors.sparse(1, [], []),)]
- df = self.spark.createDataFrame(data, ["features"])
- gmm = GaussianMixture(k=2)
- model = gmm.fit(df)
- self.assertTrue(model.hasSummary)
- s = model.summary
- self.assertTrue(isinstance(s.predictions, DataFrame))
- self.assertEqual(s.probabilityCol, "probability")
- self.assertTrue(isinstance(s.probability, DataFrame))
- self.assertEqual(s.featuresCol, "features")
- self.assertEqual(s.predictionCol, "prediction")
- self.assertTrue(isinstance(s.cluster, DataFrame))
- self.assertEqual(len(s.clusterSizes), 2)
- self.assertEqual(s.k, 2)
- self.assertEqual(s.numIter, 3)
-
- def test_bisecting_kmeans_summary(self):
- data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
- (Vectors.sparse(1, [], []),)]
- df = self.spark.createDataFrame(data, ["features"])
- bkm = BisectingKMeans(k=2)
- model = bkm.fit(df)
- self.assertTrue(model.hasSummary)
- s = model.summary
- self.assertTrue(isinstance(s.predictions, DataFrame))
- self.assertEqual(s.featuresCol, "features")
- self.assertEqual(s.predictionCol, "prediction")
- self.assertTrue(isinstance(s.cluster, DataFrame))
- self.assertEqual(len(s.clusterSizes), 2)
- self.assertEqual(s.k, 2)
- self.assertEqual(s.numIter, 20)
-
- def test_kmeans_summary(self):
- data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
- (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
- df = self.spark.createDataFrame(data, ["features"])
- kmeans = KMeans(k=2, seed=1)
- model = kmeans.fit(df)
- self.assertTrue(model.hasSummary)
- s = model.summary
- self.assertTrue(isinstance(s.predictions, DataFrame))
- self.assertEqual(s.featuresCol, "features")
- self.assertEqual(s.predictionCol, "prediction")
- self.assertTrue(isinstance(s.cluster, DataFrame))
- self.assertEqual(len(s.clusterSizes), 2)
- self.assertEqual(s.k, 2)
- self.assertEqual(s.numIter, 1)
-
-
-class KMeansTests(SparkSessionTestCase):
-
- def test_kmeans_cosine_distance(self):
- data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),),
- (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),),
- (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)]
- df = self.spark.createDataFrame(data, ["features"])
- kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine")
- model = kmeans.fit(df)
- result = model.transform(df).collect()
- self.assertTrue(result[0].prediction == result[1].prediction)
- self.assertTrue(result[2].prediction == result[3].prediction)
- self.assertTrue(result[4].prediction == result[5].prediction)
-
-
-class OneVsRestTests(SparkSessionTestCase):
-
- def test_copy(self):
- df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
- (1.0, Vectors.sparse(2, [], [])),
- (2.0, Vectors.dense(0.5, 0.5))],
- ["label", "features"])
- lr = LogisticRegression(maxIter=5, regParam=0.01)
- ovr = OneVsRest(classifier=lr)
- ovr1 = ovr.copy({lr.maxIter: 10})
- self.assertEqual(ovr.getClassifier().getMaxIter(), 5)
- self.assertEqual(ovr1.getClassifier().getMaxIter(), 10)
- model = ovr.fit(df)
- model1 = model.copy({model.predictionCol: "indexed"})
- self.assertEqual(model1.getPredictionCol(), "indexed")
-
- def test_output_columns(self):
- df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
- (1.0, Vectors.sparse(2, [], [])),
- (2.0, Vectors.dense(0.5, 0.5))],
- ["label", "features"])
- lr = LogisticRegression(maxIter=5, regParam=0.01)
- ovr = OneVsRest(classifier=lr, parallelism=1)
- model = ovr.fit(df)
- output = model.transform(df)
- self.assertEqual(output.columns, ["label", "features", "prediction"])
-
- def test_parallelism_doesnt_change_output(self):
- df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
- (1.0, Vectors.sparse(2, [], [])),
- (2.0, Vectors.dense(0.5, 0.5))],
- ["label", "features"])
- ovrPar1 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=1)
- modelPar1 = ovrPar1.fit(df)
- ovrPar2 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=2)
- modelPar2 = ovrPar2.fit(df)
- for i, model in enumerate(modelPar1.models):
- self.assertTrue(np.allclose(model.coefficients.toArray(),
- modelPar2.models[i].coefficients.toArray(), atol=1E-4))
- self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4))
-
- def test_support_for_weightCol(self):
- df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
- (1.0, Vectors.sparse(2, [], []), 1.0),
- (2.0, Vectors.dense(0.5, 0.5), 1.0)],
- ["label", "features", "weight"])
- # classifier inherits hasWeightCol
- lr = LogisticRegression(maxIter=5, regParam=0.01)
- ovr = OneVsRest(classifier=lr, weightCol="weight")
- self.assertIsNotNone(ovr.fit(df))
- # classifier doesn't inherit hasWeightCol
- dt = DecisionTreeClassifier()
- ovr2 = OneVsRest(classifier=dt, weightCol="weight")
- self.assertIsNotNone(ovr2.fit(df))
-
-
-class HashingTFTest(SparkSessionTestCase):
-
- def test_apply_binary_term_freqs(self):
-
- df = self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])
- n = 10
- hashingTF = HashingTF()
- hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True)
- output = hashingTF.transform(df)
- features = output.select("features").first().features.toArray()
- expected = Vectors.dense([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).toArray()
- for i in range(0, n):
- self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) +
- ": expected " + str(expected[i]) + ", got " + str(features[i]))
-
-
-class GeneralizedLinearRegressionTest(SparkSessionTestCase):
-
- def test_tweedie_distribution(self):
-
- df = self.spark.createDataFrame(
- [(1.0, Vectors.dense(0.0, 0.0)),
- (1.0, Vectors.dense(1.0, 2.0)),
- (2.0, Vectors.dense(0.0, 0.0)),
- (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"])
-
- glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6)
- model = glr.fit(df)
- self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4))
- self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4))
-
- model2 = glr.setLinkPower(-1.0).fit(df)
- self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4))
- self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))
-
- def test_offset(self):
-
- df = self.spark.createDataFrame(
- [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
- (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)),
- (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)),
- (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"])
-
- glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset")
- model = glr.fit(df)
- self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581],
- atol=1E-4))
- self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4))
-
-
-class LinearRegressionTest(SparkSessionTestCase):
-
- def test_linear_regression_with_huber_loss(self):
-
- data_path = "data/mllib/sample_linear_regression_data.txt"
- df = self.spark.read.format("libsvm").load(data_path)
-
- lir = LinearRegression(loss="huber", epsilon=2.0)
- model = lir.fit(df)
-
- expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537,
- 1.2612, -0.333, -0.5694, -0.6311, 0.6053]
- expectedIntercept = 0.1607
- expectedScale = 9.758
-
- self.assertTrue(
- np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3))
- self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3))
- self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3))
-
-
-class LogisticRegressionTest(SparkSessionTestCase):
-
- def test_binomial_logistic_regression_with_bound(self):
-
- df = self.spark.createDataFrame(
- [(1.0, 1.0, Vectors.dense(0.0, 5.0)),
- (0.0, 2.0, Vectors.dense(1.0, 2.0)),
- (1.0, 3.0, Vectors.dense(2.0, 1.0)),
- (0.0, 4.0, Vectors.dense(3.0, 3.0)), ], ["label", "weight", "features"])
-
- lor = LogisticRegression(regParam=0.01, weightCol="weight",
- lowerBoundsOnCoefficients=Matrices.dense(1, 2, [-1.0, -1.0]),
- upperBoundsOnIntercepts=Vectors.dense(0.0))
- model = lor.fit(df)
- self.assertTrue(
- np.allclose(model.coefficients.toArray(), [-0.2944, -0.0484], atol=1E-4))
- self.assertTrue(np.isclose(model.intercept, 0.0, atol=1E-4))
-
- def test_multinomial_logistic_regression_with_bound(self):
-
- data_path = "data/mllib/sample_multiclass_classification_data.txt"
- df = self.spark.read.format("libsvm").load(data_path)
-
- lor = LogisticRegression(regParam=0.01,
- lowerBoundsOnCoefficients=Matrices.dense(3, 4, range(12)),
- upperBoundsOnIntercepts=Vectors.dense(0.0, 0.0, 0.0))
- model = lor.fit(df)
- expected = [[4.593, 4.5516, 9.0099, 12.2904],
- [1.0, 8.1093, 7.0, 10.0],
- [3.041, 5.0, 8.0, 11.0]]
- for i in range(0, len(expected)):
- self.assertTrue(
- np.allclose(model.coefficientMatrix.toArray()[i], expected[i], atol=1E-4))
- self.assertTrue(
- np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1E-4))
-
-
-class MultilayerPerceptronClassifierTest(SparkSessionTestCase):
-
- def test_raw_and_probability_prediction(self):
-
- data_path = "data/mllib/sample_multiclass_classification_data.txt"
- df = self.spark.read.format("libsvm").load(data_path)
-
- mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[4, 5, 4, 3],
- blockSize=128, seed=123)
- model = mlp.fit(df)
- test = self.sc.parallelize([Row(features=Vectors.dense(0.1, 0.1, 0.25, 0.25))]).toDF()
- result = model.transform(test).head()
- expected_prediction = 2.0
- expected_probability = [0.0, 0.0, 1.0]
- expected_rawPrediction = [57.3955, -124.5462, 67.9943]
- self.assertTrue(result.prediction, expected_prediction)
- self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4))
- self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4))
-
-
-class FPGrowthTests(SparkSessionTestCase):
- def setUp(self):
- super(FPGrowthTests, self).setUp()
- self.data = self.spark.createDataFrame(
- [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )],
- ["items"])
-
- def test_association_rules(self):
- fp = FPGrowth()
- fpm = fp.fit(self.data)
-
- expected_association_rules = self.spark.createDataFrame(
- [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)],
- ["antecedent", "consequent", "confidence", "lift"]
- )
- actual_association_rules = fpm.associationRules
-
- self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0)
- self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0)
-
- def test_freq_itemsets(self):
- fp = FPGrowth()
- fpm = fp.fit(self.data)
-
- expected_freq_itemsets = self.spark.createDataFrame(
- [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)],
- ["items", "freq"]
- )
- actual_freq_itemsets = fpm.freqItemsets
-
- self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0)
- self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0)
-
- def tearDown(self):
- del self.data
-
-
-class ImageReaderTest(SparkSessionTestCase):
-
- def test_read_images(self):
- data_path = 'data/mllib/images/origin/kittens'
- df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
- self.assertEqual(df.count(), 4)
- first_row = df.take(1)[0][0]
- array = ImageSchema.toNDArray(first_row)
- self.assertEqual(len(array), first_row[1])
- self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row)
- self.assertEqual(df.schema, ImageSchema.imageSchema)
- self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema)
- expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24}
- self.assertEqual(ImageSchema.ocvTypes, expected)
- expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data']
- self.assertEqual(ImageSchema.imageFields, expected)
- self.assertEqual(ImageSchema.undefinedImageType, "Undefined")
-
- with QuietTest(self.sc):
- self.assertRaisesRegexp(
- TypeError,
- "image argument should be pyspark.sql.types.Row; however",
- lambda: ImageSchema.toNDArray("a"))
-
- with QuietTest(self.sc):
- self.assertRaisesRegexp(
- ValueError,
- "image argument should have attributes specified in",
- lambda: ImageSchema.toNDArray(Row(a=1)))
-
- with QuietTest(self.sc):
- self.assertRaisesRegexp(
- TypeError,
- "array argument should be numpy.ndarray; however, it got",
- lambda: ImageSchema.toImage("a"))
-
-
-class ImageReaderTest2(PySparkTestCase):
-
- @classmethod
- def setUpClass(cls):
- super(ImageReaderTest2, cls).setUpClass()
- cls.hive_available = True
- # Note that here we enable Hive's support.
- cls.spark = None
- try:
- cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
- except py4j.protocol.Py4JError:
- cls.tearDownClass()
- cls.hive_available = False
- except TypeError:
- cls.tearDownClass()
- cls.hive_available = False
- if cls.hive_available:
- cls.spark = HiveContext._createForTesting(cls.sc)
-
- def setUp(self):
- if not self.hive_available:
- self.skipTest("Hive is not available.")
-
- @classmethod
- def tearDownClass(cls):
- super(ImageReaderTest2, cls).tearDownClass()
- if cls.spark is not None:
- cls.spark.sparkSession.stop()
- cls.spark = None
-
- def test_read_images_multiple_times(self):
- # This test case is to check if `ImageSchema.readImages` tries to
- # initiate Hive client multiple times. See SPARK-22651.
- data_path = 'data/mllib/images/origin/kittens'
- ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
- ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
-
-
-class ALSTest(SparkSessionTestCase):
-
- def test_storage_levels(self):
- df = self.spark.createDataFrame(
- [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
- ["user", "item", "rating"])
- als = ALS().setMaxIter(1).setRank(1)
- # test default params
- als.fit(df)
- self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_AND_DISK")
- self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_AND_DISK")
- self.assertEqual(als.getFinalStorageLevel(), "MEMORY_AND_DISK")
- self.assertEqual(als._java_obj.getFinalStorageLevel(), "MEMORY_AND_DISK")
- # test non-default params
- als.setIntermediateStorageLevel("MEMORY_ONLY_2")
- als.setFinalStorageLevel("DISK_ONLY")
- als.fit(df)
- self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_ONLY_2")
- self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_ONLY_2")
- self.assertEqual(als.getFinalStorageLevel(), "DISK_ONLY")
- self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY")
-
-
-class DefaultValuesTests(PySparkTestCase):
- """
- Test :py:class:`JavaParams` classes to see if their default Param values match
- those in their Scala counterparts.
- """
-
- def test_java_params(self):
- import pyspark.ml.feature
- import pyspark.ml.classification
- import pyspark.ml.clustering
- import pyspark.ml.evaluation
- import pyspark.ml.pipeline
- import pyspark.ml.recommendation
- import pyspark.ml.regression
-
- modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering,
- pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation,
- pyspark.ml.regression]
- for module in modules:
- for name, cls in inspect.getmembers(module, inspect.isclass):
- if not name.endswith('Model') and not name.endswith('Params')\
- and issubclass(cls, JavaParams) and not inspect.isabstract(cls):
- # NOTE: disable check_params_exist until there is parity with Scala API
- ParamTests.check_params(self, cls(), check_params_exist=False)
-
- # Additional classes that need explicit construction
- from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel
- ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'),
- check_params_exist=False)
- ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'),
- check_params_exist=False)
-
-
-def _squared_distance(a, b):
- if isinstance(a, Vector):
- return a.squared_distance(b)
- else:
- return b.squared_distance(a)
-
-
-class VectorTests(MLlibTestCase):
-
- def _test_serialize(self, v):
- self.assertEqual(v, ser.loads(ser.dumps(v)))
- jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v)))
- nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec)))
- self.assertEqual(v, nv)
- vs = [v] * 100
- jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs)))
- nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs)))
- self.assertEqual(vs, nvs)
-
- def test_serialize(self):
- self._test_serialize(DenseVector(range(10)))
- self._test_serialize(DenseVector(array([1., 2., 3., 4.])))
- self._test_serialize(DenseVector(pyarray.array('d', range(10))))
- self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
- self._test_serialize(SparseVector(3, {}))
- self._test_serialize(DenseMatrix(2, 3, range(6)))
- sm1 = SparseMatrix(
- 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
- self._test_serialize(sm1)
-
- def test_dot(self):
- sv = SparseVector(4, {1: 1, 3: 2})
- dv = DenseVector(array([1., 2., 3., 4.]))
- lst = DenseVector([1, 2, 3, 4])
- mat = array([[1., 2., 3., 4.],
- [1., 2., 3., 4.],
- [1., 2., 3., 4.],
- [1., 2., 3., 4.]])
- arr = pyarray.array('d', [0, 1, 2, 3])
- self.assertEqual(10.0, sv.dot(dv))
- self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat)))
- self.assertEqual(30.0, dv.dot(dv))
- self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat)))
- self.assertEqual(30.0, lst.dot(dv))
- self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat)))
- self.assertEqual(7.0, sv.dot(arr))
-
- def test_squared_distance(self):
- sv = SparseVector(4, {1: 1, 3: 2})
- dv = DenseVector(array([1., 2., 3., 4.]))
- lst = DenseVector([4, 3, 2, 1])
- lst1 = [4, 3, 2, 1]
- arr = pyarray.array('d', [0, 2, 1, 3])
- narr = array([0, 2, 1, 3])
- self.assertEqual(15.0, _squared_distance(sv, dv))
- self.assertEqual(25.0, _squared_distance(sv, lst))
- self.assertEqual(20.0, _squared_distance(dv, lst))
- self.assertEqual(15.0, _squared_distance(dv, sv))
- self.assertEqual(25.0, _squared_distance(lst, sv))
- self.assertEqual(20.0, _squared_distance(lst, dv))
- self.assertEqual(0.0, _squared_distance(sv, sv))
- self.assertEqual(0.0, _squared_distance(dv, dv))
- self.assertEqual(0.0, _squared_distance(lst, lst))
- self.assertEqual(25.0, _squared_distance(sv, lst1))
- self.assertEqual(3.0, _squared_distance(sv, arr))
- self.assertEqual(3.0, _squared_distance(sv, narr))
-
- def test_hash(self):
- v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
- v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
- v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
- v4 = SparseVector(4, [(1, 1.0), (3, 2.5)])
- self.assertEqual(hash(v1), hash(v2))
- self.assertEqual(hash(v1), hash(v3))
- self.assertEqual(hash(v2), hash(v3))
- self.assertFalse(hash(v1) == hash(v4))
- self.assertFalse(hash(v2) == hash(v4))
-
- def test_eq(self):
- v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
- v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
- v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
- v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
- v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
- v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
- self.assertEqual(v1, v2)
- self.assertEqual(v1, v3)
- self.assertFalse(v2 == v4)
- self.assertFalse(v1 == v5)
- self.assertFalse(v1 == v6)
-
- def test_equals(self):
- indices = [1, 2, 4]
- values = [1., 3., 2.]
- self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.]))
- self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.]))
- self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.]))
- self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.]))
-
- def test_conversion(self):
- # numpy arrays should be automatically upcast to float64
- # tests for fix of [SPARK-5089]
- v = array([1, 2, 3, 4], dtype='float64')
- dv = DenseVector(v)
- self.assertTrue(dv.array.dtype == 'float64')
- v = array([1, 2, 3, 4], dtype='float32')
- dv = DenseVector(v)
- self.assertTrue(dv.array.dtype == 'float64')
-
- def test_sparse_vector_indexing(self):
- sv = SparseVector(5, {1: 1, 3: 2})
- self.assertEqual(sv[0], 0.)
- self.assertEqual(sv[3], 2.)
- self.assertEqual(sv[1], 1.)
- self.assertEqual(sv[2], 0.)
- self.assertEqual(sv[4], 0.)
- self.assertEqual(sv[-1], 0.)
- self.assertEqual(sv[-2], 2.)
- self.assertEqual(sv[-3], 0.)
- self.assertEqual(sv[-5], 0.)
- for ind in [5, -6]:
- self.assertRaises(IndexError, sv.__getitem__, ind)
- for ind in [7.8, '1']:
- self.assertRaises(TypeError, sv.__getitem__, ind)
-
- zeros = SparseVector(4, {})
- self.assertEqual(zeros[0], 0.0)
- self.assertEqual(zeros[3], 0.0)
- for ind in [4, -5]:
- self.assertRaises(IndexError, zeros.__getitem__, ind)
-
- empty = SparseVector(0, {})
- for ind in [-1, 0, 1]:
- self.assertRaises(IndexError, empty.__getitem__, ind)
-
- def test_sparse_vector_iteration(self):
- self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0])
- self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0])
-
- def test_matrix_indexing(self):
- mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
- expected = [[0, 6], [1, 8], [4, 10]]
- for i in range(3):
- for j in range(2):
- self.assertEqual(mat[i, j], expected[i][j])
-
- for i, j in [(-1, 0), (4, 1), (3, 4)]:
- self.assertRaises(IndexError, mat.__getitem__, (i, j))
-
- def test_repr_dense_matrix(self):
- mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
- self.assertTrue(
- repr(mat),
- 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
-
- mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True)
- self.assertTrue(
- repr(mat),
- 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
-
- mat = DenseMatrix(6, 3, zeros(18))
- self.assertTrue(
- repr(mat),
- 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \
- 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)')
-
- def test_repr_sparse_matrix(self):
- sm1t = SparseMatrix(
- 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
- isTransposed=True)
- self.assertTrue(
- repr(sm1t),
- 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)')
-
- indices = tile(arange(6), 3)
- values = ones(18)
- sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values)
- self.assertTrue(
- repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \
- [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \
- [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \
- 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)")
-
- self.assertTrue(
- str(sm),
- "6 X 3 CSCMatrix\n\
- (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\
- (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\
- (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..")
-
- sm = SparseMatrix(1, 18, zeros(19), [], [])
- self.assertTrue(
- repr(sm),
- 'SparseMatrix(1, 18, \
- [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)')
-
- def test_sparse_matrix(self):
- # Test sparse matrix creation.
- sm1 = SparseMatrix(
- 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
- self.assertEqual(sm1.numRows, 3)
- self.assertEqual(sm1.numCols, 4)
- self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4])
- self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2])
- self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0])
- self.assertTrue(
- repr(sm1),
- 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)')
-
- # Test indexing
- expected = [
- [0, 0, 0, 0],
- [1, 0, 4, 0],
- [2, 0, 5, 0]]
-
- for i in range(3):
- for j in range(4):
- self.assertEqual(expected[i][j], sm1[i, j])
- self.assertTrue(array_equal(sm1.toArray(), expected))
-
- for i, j in [(-1, 1), (4, 3), (3, 5)]:
- self.assertRaises(IndexError, sm1.__getitem__, (i, j))
-
- # Test conversion to dense and sparse.
- smnew = sm1.toDense().toSparse()
- self.assertEqual(sm1.numRows, smnew.numRows)
- self.assertEqual(sm1.numCols, smnew.numCols)
- self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs))
- self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices))
- self.assertTrue(array_equal(sm1.values, smnew.values))
-
- sm1t = SparseMatrix(
- 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
- isTransposed=True)
- self.assertEqual(sm1t.numRows, 3)
- self.assertEqual(sm1t.numCols, 4)
- self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5])
- self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2])
- self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0])
-
- expected = [
- [3, 2, 0, 0],
- [0, 0, 4, 0],
- [9, 0, 8, 0]]
-
- for i in range(3):
- for j in range(4):
- self.assertEqual(expected[i][j], sm1t[i, j])
- self.assertTrue(array_equal(sm1t.toArray(), expected))
-
- def test_dense_matrix_is_transposed(self):
- mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True)
- mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9])
- self.assertEqual(mat1, mat)
-
- expected = [[0, 4], [1, 6], [3, 9]]
- for i in range(3):
- for j in range(2):
- self.assertEqual(mat1[i, j], expected[i][j])
- self.assertTrue(array_equal(mat1.toArray(), expected))
-
- sm = mat1.toSparse()
- self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2]))
- self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
- self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
-
- def test_norms(self):
- a = DenseVector([0, 2, 3, -1])
- self.assertAlmostEqual(a.norm(2), 3.742, 3)
- self.assertTrue(a.norm(1), 6)
- self.assertTrue(a.norm(inf), 3)
- a = SparseVector(4, [0, 2], [3, -4])
- self.assertAlmostEqual(a.norm(2), 5)
- self.assertTrue(a.norm(1), 7)
- self.assertTrue(a.norm(inf), 4)
-
- tmp = SparseVector(4, [0, 2], [3, 0])
- self.assertEqual(tmp.numNonzeros(), 1)
-
-
-class VectorUDTTests(MLlibTestCase):
-
- dv0 = DenseVector([])
- dv1 = DenseVector([1.0, 2.0])
- sv0 = SparseVector(2, [], [])
- sv1 = SparseVector(2, [1], [2.0])
- udt = VectorUDT()
-
- def test_json_schema(self):
- self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt)
-
- def test_serialization(self):
- for v in [self.dv0, self.dv1, self.sv0, self.sv1]:
- self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
-
- def test_infer_schema(self):
- rdd = self.sc.parallelize([Row(label=1.0, features=self.dv1),
- Row(label=0.0, features=self.sv1)])
- df = rdd.toDF()
- schema = df.schema
- field = [f for f in schema.fields if f.name == "features"][0]
- self.assertEqual(field.dataType, self.udt)
- vectors = df.rdd.map(lambda p: p.features).collect()
- self.assertEqual(len(vectors), 2)
- for v in vectors:
- if isinstance(v, SparseVector):
- self.assertEqual(v, self.sv1)
- elif isinstance(v, DenseVector):
- self.assertEqual(v, self.dv1)
- else:
- raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
-
-
-class MatrixUDTTests(MLlibTestCase):
-
- dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
- dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
- sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
- sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
- udt = MatrixUDT()
-
- def test_json_schema(self):
- self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
-
- def test_serialization(self):
- for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
- self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
-
- def test_infer_schema(self):
- rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
- df = rdd.toDF()
- schema = df.schema
- self.assertTrue(schema.fields[1].dataType, self.udt)
- matrices = df.rdd.map(lambda x: x._2).collect()
- self.assertEqual(len(matrices), 2)
- for m in matrices:
- if isinstance(m, DenseMatrix):
- self.assertTrue(m, self.dm1)
- elif isinstance(m, SparseMatrix):
- self.assertTrue(m, self.sm1)
- else:
- raise ValueError("Expected a matrix but got type %r" % type(m))
-
-
-class WrapperTests(MLlibTestCase):
-
- def test_new_java_array(self):
- # test array of strings
- str_list = ["a", "b", "c"]
- java_class = self.sc._gateway.jvm.java.lang.String
- java_array = JavaWrapper._new_java_array(str_list, java_class)
- self.assertEqual(_java2py(self.sc, java_array), str_list)
- # test array of integers
- int_list = [1, 2, 3]
- java_class = self.sc._gateway.jvm.java.lang.Integer
- java_array = JavaWrapper._new_java_array(int_list, java_class)
- self.assertEqual(_java2py(self.sc, java_array), int_list)
- # test array of floats
- float_list = [0.1, 0.2, 0.3]
- java_class = self.sc._gateway.jvm.java.lang.Double
- java_array = JavaWrapper._new_java_array(float_list, java_class)
- self.assertEqual(_java2py(self.sc, java_array), float_list)
- # test array of bools
- bool_list = [False, True, True]
- java_class = self.sc._gateway.jvm.java.lang.Boolean
- java_array = JavaWrapper._new_java_array(bool_list, java_class)
- self.assertEqual(_java2py(self.sc, java_array), bool_list)
- # test array of Java DenseVectors
- v1 = DenseVector([0.0, 1.0])
- v2 = DenseVector([1.0, 0.0])
- vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)]
- java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector
- java_array = JavaWrapper._new_java_array(vec_java_list, java_class)
- self.assertEqual(_java2py(self.sc, java_array), [v1, v2])
- # test empty array
- java_class = self.sc._gateway.jvm.java.lang.Integer
- java_array = JavaWrapper._new_java_array([], java_class)
- self.assertEqual(_java2py(self.sc, java_array), [])
-
-
-class ChiSquareTestTests(SparkSessionTestCase):
-
- def test_chisquaretest(self):
- data = [[0, Vectors.dense([0, 1, 2])],
- [1, Vectors.dense([1, 1, 1])],
- [2, Vectors.dense([2, 1, 0])]]
- df = self.spark.createDataFrame(data, ['label', 'feat'])
- res = ChiSquareTest.test(df, 'feat', 'label')
- # This line is hitting the collect bug described in #17218, commented for now.
- # pValues = res.select("degreesOfFreedom").collect())
- self.assertIsInstance(res, DataFrame)
- fieldNames = set(field.name for field in res.schema.fields)
- expectedFields = ["pValues", "degreesOfFreedom", "statistics"]
- self.assertTrue(all(field in fieldNames for field in expectedFields))
-
-
-class UnaryTransformerTests(SparkSessionTestCase):
-
- def test_unary_transformer_validate_input_type(self):
- shiftVal = 3
- transformer = MockUnaryTransformer(shiftVal=shiftVal)\
- .setInputCol("input").setOutputCol("output")
-
- # should not raise any errors
- transformer.validateInputType(DoubleType())
-
- with self.assertRaises(TypeError):
- # passing the wrong input type should raise an error
- transformer.validateInputType(IntegerType())
-
- def test_unary_transformer_transform(self):
- shiftVal = 3
- transformer = MockUnaryTransformer(shiftVal=shiftVal)\
- .setInputCol("input").setOutputCol("output")
-
- df = self.spark.range(0, 10).toDF('input')
- df = df.withColumn("input", df.input.cast(dataType="double"))
-
- transformed_df = transformer.transform(df)
- results = transformed_df.select("input", "output").collect()
-
- for res in results:
- self.assertEqual(res.input + shiftVal, res.output)
-
-
-class EstimatorTest(unittest.TestCase):
-
- def testDefaultFitMultiple(self):
- N = 4
- data = MockDataset()
- estimator = MockEstimator()
- params = [{estimator.fake: i} for i in range(N)]
- modelIter = estimator.fitMultiple(data, params)
- indexList = []
- for index, model in modelIter:
- self.assertEqual(model.getFake(), index)
- indexList.append(index)
- self.assertEqual(sorted(indexList), list(range(N)))
-
-
-if __name__ == "__main__":
- from pyspark.ml.tests import *
-
- runner = unishark.BufferedTestRunner(
- reporters=[unishark.XUnitReporter('target/test-reports/pyspark.ml_{}'.format(
- os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))])
- unittest.main(testRunner=runner, verbosity=2)
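The monolithic `python/pyspark/ml/tests.py` above (driven by a single unishark-based `__main__` block) is removed in favor of the per-module test files added below, each carrying its own xmlrunner fallback. As a rough sketch only — the path and environment assumptions here are mine, not part of this change — the split-out modules can still be run together with plain `unittest` discovery, provided `pyspark` and its dependencies are importable:

```python
# Sketch (not part of this diff): drive the new per-module ML tests with plain
# unittest discovery. Assumes the working directory is the root of a Spark
# checkout and that pyspark's dependencies (e.g. py4j) are installed.
import unittest

if __name__ == "__main__":
    suite = unittest.defaultTestLoader.discover(
        start_dir="python/pyspark/ml/tests",
        pattern="test_*.py",
        top_level_dir="python")
    unittest.TextTestRunner(verbosity=2).run(suite)
```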
diff --git a/python/pyspark/ml/tests/__init__.py b/python/pyspark/ml/tests/__init__.py
new file mode 100644
index 0000000000000..cce3acad34a49
--- /dev/null
+++ b/python/pyspark/ml/tests/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py
new file mode 100644
index 0000000000000..516bb563402e0
--- /dev/null
+++ b/python/pyspark/ml/tests/test_algorithms.py
@@ -0,0 +1,340 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from shutil import rmtree
+import tempfile
+import unittest
+
+import numpy as np
+
+from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, \
+ MultilayerPerceptronClassifier, OneVsRest
+from pyspark.ml.clustering import DistributedLDAModel, KMeans, LocalLDAModel, LDA, LDAModel
+from pyspark.ml.fpm import FPGrowth
+from pyspark.ml.linalg import Matrices, Vectors
+from pyspark.ml.recommendation import ALS
+from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
+from pyspark.sql import Row
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class LogisticRegressionTest(SparkSessionTestCase):
+
+ def test_binomial_logistic_regression_with_bound(self):
+
+ df = self.spark.createDataFrame(
+ [(1.0, 1.0, Vectors.dense(0.0, 5.0)),
+ (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+ (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+ (0.0, 4.0, Vectors.dense(3.0, 3.0)), ], ["label", "weight", "features"])
+
+ lor = LogisticRegression(regParam=0.01, weightCol="weight",
+ lowerBoundsOnCoefficients=Matrices.dense(1, 2, [-1.0, -1.0]),
+ upperBoundsOnIntercepts=Vectors.dense(0.0))
+ model = lor.fit(df)
+ self.assertTrue(
+ np.allclose(model.coefficients.toArray(), [-0.2944, -0.0484], atol=1E-4))
+ self.assertTrue(np.isclose(model.intercept, 0.0, atol=1E-4))
+
+ def test_multinomial_logistic_regression_with_bound(self):
+
+ data_path = "data/mllib/sample_multiclass_classification_data.txt"
+ df = self.spark.read.format("libsvm").load(data_path)
+
+ lor = LogisticRegression(regParam=0.01,
+ lowerBoundsOnCoefficients=Matrices.dense(3, 4, range(12)),
+ upperBoundsOnIntercepts=Vectors.dense(0.0, 0.0, 0.0))
+ model = lor.fit(df)
+ expected = [[4.593, 4.5516, 9.0099, 12.2904],
+ [1.0, 8.1093, 7.0, 10.0],
+ [3.041, 5.0, 8.0, 11.0]]
+ for i in range(0, len(expected)):
+ self.assertTrue(
+ np.allclose(model.coefficientMatrix.toArray()[i], expected[i], atol=1E-4))
+ self.assertTrue(
+ np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1E-4))
+
+
+class MultilayerPerceptronClassifierTest(SparkSessionTestCase):
+
+ def test_raw_and_probability_prediction(self):
+
+ data_path = "data/mllib/sample_multiclass_classification_data.txt"
+ df = self.spark.read.format("libsvm").load(data_path)
+
+ mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[4, 5, 4, 3],
+ blockSize=128, seed=123)
+ model = mlp.fit(df)
+ test = self.sc.parallelize([Row(features=Vectors.dense(0.1, 0.1, 0.25, 0.25))]).toDF()
+ result = model.transform(test).head()
+ expected_prediction = 2.0
+ expected_probability = [0.0, 0.0, 1.0]
+ expected_rawPrediction = [57.3955, -124.5462, 67.9943]
+ self.assertEqual(result.prediction, expected_prediction)
+ self.assertTrue(np.allclose(result.probability, expected_probability, atol=1E-4))
+ self.assertTrue(np.allclose(result.rawPrediction, expected_rawPrediction, atol=1E-4))
+
+
+class OneVsRestTests(SparkSessionTestCase):
+
+ def test_copy(self):
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ (1.0, Vectors.sparse(2, [], [])),
+ (2.0, Vectors.dense(0.5, 0.5))],
+ ["label", "features"])
+ lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ovr = OneVsRest(classifier=lr)
+ ovr1 = ovr.copy({lr.maxIter: 10})
+ self.assertEqual(ovr.getClassifier().getMaxIter(), 5)
+ self.assertEqual(ovr1.getClassifier().getMaxIter(), 10)
+ model = ovr.fit(df)
+ model1 = model.copy({model.predictionCol: "indexed"})
+ self.assertEqual(model1.getPredictionCol(), "indexed")
+
+ def test_output_columns(self):
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ (1.0, Vectors.sparse(2, [], [])),
+ (2.0, Vectors.dense(0.5, 0.5))],
+ ["label", "features"])
+ lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ovr = OneVsRest(classifier=lr, parallelism=1)
+ model = ovr.fit(df)
+ output = model.transform(df)
+ self.assertEqual(output.columns, ["label", "features", "prediction"])
+
+ def test_parallelism_doesnt_change_output(self):
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ (1.0, Vectors.sparse(2, [], [])),
+ (2.0, Vectors.dense(0.5, 0.5))],
+ ["label", "features"])
+ ovrPar1 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=1)
+ modelPar1 = ovrPar1.fit(df)
+ ovrPar2 = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=.01), parallelism=2)
+ modelPar2 = ovrPar2.fit(df)
+ for i, model in enumerate(modelPar1.models):
+ self.assertTrue(np.allclose(model.coefficients.toArray(),
+ modelPar2.models[i].coefficients.toArray(), atol=1E-4))
+ self.assertTrue(np.allclose(model.intercept, modelPar2.models[i].intercept, atol=1E-4))
+
+ def test_support_for_weightCol(self):
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
+ (1.0, Vectors.sparse(2, [], []), 1.0),
+ (2.0, Vectors.dense(0.5, 0.5), 1.0)],
+ ["label", "features", "weight"])
+ # classifier inherits hasWeightCol
+ lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ovr = OneVsRest(classifier=lr, weightCol="weight")
+ self.assertIsNotNone(ovr.fit(df))
+ # classifier doesn't inherit hasWeightCol
+ dt = DecisionTreeClassifier()
+ ovr2 = OneVsRest(classifier=dt, weightCol="weight")
+ self.assertIsNotNone(ovr2.fit(df))
+
+
+class KMeansTests(SparkSessionTestCase):
+
+ def test_kmeans_cosine_distance(self):
+ data = [(Vectors.dense([1.0, 1.0]),), (Vectors.dense([10.0, 10.0]),),
+ (Vectors.dense([1.0, 0.5]),), (Vectors.dense([10.0, 4.4]),),
+ (Vectors.dense([-1.0, 1.0]),), (Vectors.dense([-100.0, 90.0]),)]
+ df = self.spark.createDataFrame(data, ["features"])
+ kmeans = KMeans(k=3, seed=1, distanceMeasure="cosine")
+ model = kmeans.fit(df)
+ result = model.transform(df).collect()
+ self.assertTrue(result[0].prediction == result[1].prediction)
+ self.assertTrue(result[2].prediction == result[3].prediction)
+ self.assertTrue(result[4].prediction == result[5].prediction)
+
+
+class LDATest(SparkSessionTestCase):
+
+ def _compare(self, m1, m2):
+ """
+ Temp method for comparing instances.
+ TODO: Replace with generic implementation once SPARK-14706 is merged.
+ """
+ self.assertEqual(m1.uid, m2.uid)
+ self.assertEqual(type(m1), type(m2))
+ self.assertEqual(len(m1.params), len(m2.params))
+ for p in m1.params:
+ if m1.isDefined(p):
+ self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
+ self.assertEqual(p.parent, m2.getParam(p.name).parent)
+ if isinstance(m1, LDAModel):
+ self.assertEqual(m1.vocabSize(), m2.vocabSize())
+ self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix())
+
+ def test_persistence(self):
+ # Test save/load for LDA, LocalLDAModel, DistributedLDAModel.
+ df = self.spark.createDataFrame([
+ [1, Vectors.dense([0.0, 1.0])],
+ [2, Vectors.sparse(2, {0: 1.0})],
+ ], ["id", "features"])
+ # Fit model
+ lda = LDA(k=2, seed=1, optimizer="em")
+ distributedModel = lda.fit(df)
+ self.assertTrue(distributedModel.isDistributed())
+ localModel = distributedModel.toLocal()
+ self.assertFalse(localModel.isDistributed())
+ # Define paths
+ path = tempfile.mkdtemp()
+ lda_path = path + "/lda"
+ dist_model_path = path + "/distLDAModel"
+ local_model_path = path + "/localLDAModel"
+ # Test LDA
+ lda.save(lda_path)
+ lda2 = LDA.load(lda_path)
+ self._compare(lda, lda2)
+ # Test DistributedLDAModel
+ distributedModel.save(dist_model_path)
+ distributedModel2 = DistributedLDAModel.load(dist_model_path)
+ self._compare(distributedModel, distributedModel2)
+ # Test LocalLDAModel
+ localModel.save(local_model_path)
+ localModel2 = LocalLDAModel.load(local_model_path)
+ self._compare(localModel, localModel2)
+ # Clean up
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+
+class FPGrowthTests(SparkSessionTestCase):
+ def setUp(self):
+ super(FPGrowthTests, self).setUp()
+ self.data = self.spark.createDataFrame(
+ [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )],
+ ["items"])
+
+ def test_association_rules(self):
+ fp = FPGrowth()
+ fpm = fp.fit(self.data)
+
+ expected_association_rules = self.spark.createDataFrame(
+ [([3], [1], 1.0, 1.0), ([2], [1], 1.0, 1.0)],
+ ["antecedent", "consequent", "confidence", "lift"]
+ )
+ actual_association_rules = fpm.associationRules
+
+ self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0)
+ self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0)
+
+ def test_freq_itemsets(self):
+ fp = FPGrowth()
+ fpm = fp.fit(self.data)
+
+ expected_freq_itemsets = self.spark.createDataFrame(
+ [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)],
+ ["items", "freq"]
+ )
+ actual_freq_itemsets = fpm.freqItemsets
+
+ self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0)
+ self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0)
+
+ def tearDown(self):
+ del self.data
+
+
+class ALSTest(SparkSessionTestCase):
+
+ def test_storage_levels(self):
+ df = self.spark.createDataFrame(
+ [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
+ ["user", "item", "rating"])
+ als = ALS().setMaxIter(1).setRank(1)
+ # test default params
+ als.fit(df)
+ self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_AND_DISK")
+ self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_AND_DISK")
+ self.assertEqual(als.getFinalStorageLevel(), "MEMORY_AND_DISK")
+ self.assertEqual(als._java_obj.getFinalStorageLevel(), "MEMORY_AND_DISK")
+ # test non-default params
+ als.setIntermediateStorageLevel("MEMORY_ONLY_2")
+ als.setFinalStorageLevel("DISK_ONLY")
+ als.fit(df)
+ self.assertEqual(als.getIntermediateStorageLevel(), "MEMORY_ONLY_2")
+ self.assertEqual(als._java_obj.getIntermediateStorageLevel(), "MEMORY_ONLY_2")
+ self.assertEqual(als.getFinalStorageLevel(), "DISK_ONLY")
+ self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY")
+
+
+class GeneralizedLinearRegressionTest(SparkSessionTestCase):
+
+ def test_tweedie_distribution(self):
+
+ df = self.spark.createDataFrame(
+ [(1.0, Vectors.dense(0.0, 0.0)),
+ (1.0, Vectors.dense(1.0, 2.0)),
+ (2.0, Vectors.dense(0.0, 0.0)),
+ (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"])
+
+ glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6)
+ model = glr.fit(df)
+ self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4))
+ self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4))
+
+ model2 = glr.setLinkPower(-1.0).fit(df)
+ self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4))
+ self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))
+
+ def test_offset(self):
+
+ df = self.spark.createDataFrame(
+ [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
+ (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)),
+ (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)),
+ (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"])
+
+ glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset")
+ model = glr.fit(df)
+ self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581],
+ atol=1E-4))
+ self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4))
+
+
+class LinearRegressionTest(SparkSessionTestCase):
+
+ def test_linear_regression_with_huber_loss(self):
+
+ data_path = "data/mllib/sample_linear_regression_data.txt"
+ df = self.spark.read.format("libsvm").load(data_path)
+
+ lir = LinearRegression(loss="huber", epsilon=2.0)
+ model = lir.fit(df)
+
+ expectedCoefficients = [0.136, 0.7648, -0.7761, 2.4236, 0.537,
+ 1.2612, -0.333, -0.5694, -0.6311, 0.6053]
+ expectedIntercept = 0.1607
+ expectedScale = 9.758
+
+ self.assertTrue(
+ np.allclose(model.coefficients.toArray(), expectedCoefficients, atol=1E-3))
+ self.assertTrue(np.isclose(model.intercept, expectedIntercept, atol=1E-3))
+ self.assertTrue(np.isclose(model.scale, expectedScale, atol=1E-3))
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_algorithms import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
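The coefficient and intercept checks in `test_algorithms.py` repeat the same `np.allclose(... .toArray(), ..., atol=1E-4)` pattern. A small hypothetical helper (not part of this diff; the name is invented) could centralize the tolerance and give a more descriptive failure message:

```python
# Hypothetical helper (not part of this diff): shared tolerance check for the
# fitted-coefficient assertions used throughout test_algorithms.py.
import numpy as np

def assert_vector_close(testcase, vector, expected, atol=1e-4):
    actual = vector.toArray()  # DenseVector and SparseVector both support toArray()
    testcase.assertTrue(
        np.allclose(actual, expected, atol=atol),
        "expected %s, got %s (atol=%g)" % (expected, list(actual), atol))
```

Inside the tests above it would be called as, for example, `assert_vector_close(self, model.coefficients, [-0.2944, -0.0484])`.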
diff --git a/python/pyspark/ml/tests/test_base.py b/python/pyspark/ml/tests/test_base.py
new file mode 100644
index 0000000000000..31e3deb53046c
--- /dev/null
+++ b/python/pyspark/ml/tests/test_base.py
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.types import DoubleType, IntegerType
+from pyspark.testing.mlutils import MockDataset, MockEstimator, MockUnaryTransformer, \
+ SparkSessionTestCase
+
+
+class UnaryTransformerTests(SparkSessionTestCase):
+
+ def test_unary_transformer_validate_input_type(self):
+ shiftVal = 3
+ transformer = MockUnaryTransformer(shiftVal=shiftVal) \
+ .setInputCol("input").setOutputCol("output")
+
+ # should not raise any errors
+ transformer.validateInputType(DoubleType())
+
+ with self.assertRaises(TypeError):
+ # passing the wrong input type should raise an error
+ transformer.validateInputType(IntegerType())
+
+ def test_unary_transformer_transform(self):
+ shiftVal = 3
+ transformer = MockUnaryTransformer(shiftVal=shiftVal) \
+ .setInputCol("input").setOutputCol("output")
+
+ df = self.spark.range(0, 10).toDF('input')
+ df = df.withColumn("input", df.input.cast(dataType="double"))
+
+ transformed_df = transformer.transform(df)
+ results = transformed_df.select("input", "output").collect()
+
+ for res in results:
+ self.assertEqual(res.input + shiftVal, res.output)
+
+
+class EstimatorTest(unittest.TestCase):
+
+ def testDefaultFitMultiple(self):
+ N = 4
+ data = MockDataset()
+ estimator = MockEstimator()
+ params = [{estimator.fake: i} for i in range(N)]
+ modelIter = estimator.fitMultiple(data, params)
+ indexList = []
+ for index, model in modelIter:
+ self.assertEqual(model.getFake(), index)
+ indexList.append(index)
+ self.assertEqual(sorted(indexList), list(range(N)))
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_base import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
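`MockUnaryTransformer` and `MockEstimator` are defined in `pyspark.testing.mlutils`, not in this diff. Purely to illustrate the contract that `UnaryTransformerTests` exercises (`validateInputType`, `createTransformFunc`, `outputDataType`), here is a minimal hedged sketch of such a transformer; the class name is invented and the mock's `shiftVal` Param plumbing is omitted:

```python
# Hedged sketch (not the real MockUnaryTransformer): a minimal UnaryTransformer
# showing the three methods the tests above rely on. The shift is stored as a
# plain attribute rather than a Param, so copy()/persistence would not carry it.
from pyspark.ml import UnaryTransformer
from pyspark.sql.types import DoubleType

class ShiftDoubleTransformer(UnaryTransformer):
    def __init__(self, shift=0.0):
        super(ShiftDoubleTransformer, self).__init__()
        self.shift = shift

    def createTransformFunc(self):
        shift = self.shift
        return lambda value: value + shift  # applied to each input cell

    def outputDataType(self):
        return DoubleType()

    def validateInputType(self, inputType):
        if inputType != DoubleType():
            raise TypeError("Expected DoubleType, got %s" % inputType)
```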
diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py
new file mode 100644
index 0000000000000..5438455a6f756
--- /dev/null
+++ b/python/pyspark/ml/tests/test_evaluation.py
@@ -0,0 +1,63 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import numpy as np
+
+from pyspark.ml.evaluation import ClusteringEvaluator, RegressionEvaluator
+from pyspark.ml.linalg import Vectors
+from pyspark.sql import Row
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class EvaluatorTests(SparkSessionTestCase):
+
+ def test_java_params(self):
+ """
+ This tests a bug fixed by SPARK-18274 which causes multiple copies
+ of a Params instance in Python to be linked to the same Java instance.
+ """
+ evaluator = RegressionEvaluator(metricName="r2")
+ df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)])
+ evaluator.evaluate(df)
+ self.assertEqual(evaluator._java_obj.getMetricName(), "r2")
+ evaluatorCopy = evaluator.copy({evaluator.metricName: "mae"})
+ evaluator.evaluate(df)
+ evaluatorCopy.evaluate(df)
+ self.assertEqual(evaluator._java_obj.getMetricName(), "r2")
+ self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae")
+
+ def test_clustering_evaluator_with_cosine_distance(self):
+ featureAndPredictions = map(lambda x: (Vectors.dense(x[0]), x[1]),
+ [([1.0, 1.0], 1.0), ([10.0, 10.0], 1.0), ([1.0, 0.5], 2.0),
+ ([10.0, 4.4], 2.0), ([-1.0, 1.0], 3.0), ([-100.0, 90.0], 3.0)])
+ dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
+ evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine")
+ self.assertEqual(evaluator.getDistanceMeasure(), "cosine")
+ self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5))
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_evaluation import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
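The cosine-distance `ClusteringEvaluator` check above pairs naturally with the cosine-distance `KMeans` test added in `test_algorithms.py`. A short illustrative sketch (the function name is mine, not part of this change) of scoring a fitted clustering model the same way, assuming a DataFrame with the default "features" column as in the tests:

```python
# Illustrative sketch (not part of this diff): fit cosine-distance KMeans and
# score it with the cosine ClusteringEvaluator exercised in EvaluatorTests.
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator

def cosine_silhouette(df, k=3, seed=1):
    model = KMeans(k=k, seed=seed, distanceMeasure="cosine").fit(df)
    predictions = model.transform(df)  # adds the default "prediction" column
    evaluator = ClusteringEvaluator(distanceMeasure="cosine")
    return evaluator.evaluate(predictions)
```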
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
new file mode 100644
index 0000000000000..325feaba66957
--- /dev/null
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -0,0 +1,311 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import unittest
+
+if sys.version > '3':
+ basestring = str
+
+from pyspark.ml.feature import Binarizer, CountVectorizer, CountVectorizerModel, HashingTF, IDF, \
+ NGram, RFormula, StopWordsRemover, StringIndexer, StringIndexerModel, VectorSizeHint
+from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
+from pyspark.sql import Row
+from pyspark.testing.utils import QuietTest
+from pyspark.testing.mlutils import check_params, SparkSessionTestCase
+
+
+class FeatureTests(SparkSessionTestCase):
+
+ def test_binarizer(self):
+ b0 = Binarizer()
+ self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold])
+ self.assertTrue(all([~b0.isSet(p) for p in b0.params]))
+ self.assertTrue(b0.hasDefault(b0.threshold))
+ self.assertEqual(b0.getThreshold(), 0.0)
+ b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0)
+ self.assertTrue(all([b0.isSet(p) for p in b0.params]))
+ self.assertEqual(b0.getThreshold(), 1.0)
+ self.assertEqual(b0.getInputCol(), "input")
+ self.assertEqual(b0.getOutputCol(), "output")
+
+ b0c = b0.copy({b0.threshold: 2.0})
+ self.assertEqual(b0c.uid, b0.uid)
+ self.assertListEqual(b0c.params, b0.params)
+ self.assertEqual(b0c.getThreshold(), 2.0)
+
+ b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output")
+ self.assertNotEqual(b1.uid, b0.uid)
+ self.assertEqual(b1.getThreshold(), 2.0)
+ self.assertEqual(b1.getInputCol(), "input")
+ self.assertEqual(b1.getOutputCol(), "output")
+
+ def test_idf(self):
+ dataset = self.spark.createDataFrame([
+ (DenseVector([1.0, 2.0]),),
+ (DenseVector([0.0, 1.0]),),
+ (DenseVector([3.0, 0.2]),)], ["tf"])
+ idf0 = IDF(inputCol="tf")
+ self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol])
+ idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"})
+ self.assertEqual(idf0m.uid, idf0.uid,
+ "Model should inherit the UID from its parent estimator.")
+ output = idf0m.transform(dataset)
+ self.assertIsNotNone(output.head().idf)
+        # Test that parameters are transferred to the Python model
+ check_params(self, idf0m)
+
+ def test_ngram(self):
+ dataset = self.spark.createDataFrame([
+ Row(input=["a", "b", "c", "d", "e"])])
+ ngram0 = NGram(n=4, inputCol="input", outputCol="output")
+ self.assertEqual(ngram0.getN(), 4)
+ self.assertEqual(ngram0.getInputCol(), "input")
+ self.assertEqual(ngram0.getOutputCol(), "output")
+ transformedDF = ngram0.transform(dataset)
+ self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"])
+
+ def test_stopwordsremover(self):
+ dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
+ stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
+ # Default
+ self.assertEqual(stopWordRemover.getInputCol(), "input")
+ transformedDF = stopWordRemover.transform(dataset)
+ self.assertEqual(transformedDF.head().output, ["panda"])
+ self.assertEqual(type(stopWordRemover.getStopWords()), list)
+ self.assertTrue(isinstance(stopWordRemover.getStopWords()[0], basestring))
+ # Custom
+ stopwords = ["panda"]
+ stopWordRemover.setStopWords(stopwords)
+ self.assertEqual(stopWordRemover.getInputCol(), "input")
+ self.assertEqual(stopWordRemover.getStopWords(), stopwords)
+ transformedDF = stopWordRemover.transform(dataset)
+ self.assertEqual(transformedDF.head().output, ["a"])
+ # with language selection
+ stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
+ dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
+ stopWordRemover.setStopWords(stopwords)
+ self.assertEqual(stopWordRemover.getStopWords(), stopwords)
+ transformedDF = stopWordRemover.transform(dataset)
+ self.assertEqual(transformedDF.head().output, [])
+ # with locale
+ stopwords = ["BELKİ"]
+ dataset = self.spark.createDataFrame([Row(input=["belki"])])
+ stopWordRemover.setStopWords(stopwords).setLocale("tr")
+ self.assertEqual(stopWordRemover.getStopWords(), stopwords)
+ transformedDF = stopWordRemover.transform(dataset)
+ self.assertEqual(transformedDF.head().output, [])
+
+ def test_count_vectorizer_with_binary(self):
+ dataset = self.spark.createDataFrame([
+ (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
+ (1, "a a".split(' '), SparseVector(3, {0: 1.0}),),
+ (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
+ (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"])
+ cv = CountVectorizer(binary=True, inputCol="words", outputCol="features")
+ model = cv.fit(dataset)
+
+ transformedList = model.transform(dataset).select("features", "expected").collect()
+
+ for r in transformedList:
+ feature, expected = r
+ self.assertEqual(feature, expected)
+
+ def test_count_vectorizer_with_maxDF(self):
+ dataset = self.spark.createDataFrame([
+ (0, "a b c d".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
+ (1, "a b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
+ (2, "a b".split(' '), SparseVector(3, {0: 1.0}),),
+ (3, "a".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])
+ cv = CountVectorizer(inputCol="words", outputCol="features")
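+        # maxDF=3 drops terms appearing in more than 3 documents, so 'a' (in all 4) is excluded.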
+ model1 = cv.setMaxDF(3).fit(dataset)
+ self.assertEqual(model1.vocabulary, ['b', 'c', 'd'])
+
+ transformedList1 = model1.transform(dataset).select("features", "expected").collect()
+
+ for r in transformedList1:
+ feature, expected = r
+ self.assertEqual(feature, expected)
+
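+        # A fractional maxDF is a proportion of documents; 0.75 likewise excludes 'a'.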
+ model2 = cv.setMaxDF(0.75).fit(dataset)
+ self.assertEqual(model2.vocabulary, ['b', 'c', 'd'])
+
+ transformedList2 = model2.transform(dataset).select("features", "expected").collect()
+
+ for r in transformedList2:
+ feature, expected = r
+ self.assertEqual(feature, expected)
+
+ def test_count_vectorizer_from_vocab(self):
+ model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words",
+ outputCol="features", minTF=2)
+ self.assertEqual(model.vocabulary, ["a", "b", "c"])
+ self.assertEqual(model.getMinTF(), 2)
+
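+        # minTF=2 requires a term to occur at least twice in a document to be counted.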
+ dataset = self.spark.createDataFrame([
+ (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),),
+ (1, "a a".split(' '), SparseVector(3, {0: 2.0}),),
+ (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])
+
+ transformed_list = model.transform(dataset).select("features", "expected").collect()
+
+ for r in transformed_list:
+ feature, expected = r
+ self.assertEqual(feature, expected)
+
+ # Test an empty vocabulary
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
+ CountVectorizerModel.from_vocabulary([], inputCol="words")
+
+        # Test that a model with default settings can transform
+ model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words")
+ transformed_list = model_default.transform(dataset) \
+ .select(model_default.getOrDefault(model_default.outputCol)).collect()
+ self.assertEqual(len(transformed_list), 3)
+
+ def test_rformula_force_index_label(self):
+ df = self.spark.createDataFrame([
+ (1.0, 1.0, "a"),
+ (0.0, 2.0, "b"),
+ (1.0, 0.0, "a")], ["y", "x", "s"])
+ # Does not index label by default since it's numeric type.
+ rf = RFormula(formula="y ~ x + s")
+ model = rf.fit(df)
+ transformedDF = model.transform(df)
+ self.assertEqual(transformedDF.head().label, 1.0)
+ # Force to index label.
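+        # The indexer orders labels by frequency, so the most frequent label 1.0 maps to 0.0.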
+ rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True)
+ model2 = rf2.fit(df)
+ transformedDF2 = model2.transform(df)
+ self.assertEqual(transformedDF2.head().label, 0.0)
+
+ def test_rformula_string_indexer_order_type(self):
+ df = self.spark.createDataFrame([
+ (1.0, 1.0, "a"),
+ (0.0, 2.0, "b"),
+ (1.0, 0.0, "a")], ["y", "x", "s"])
+ rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
+ self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc')
+ transformedDF = rf.fit(df).transform(df)
+ observed = transformedDF.select("features").collect()
+ expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
+ for i in range(0, len(expected)):
+ self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
+
+ def test_string_indexer_handle_invalid(self):
+ df = self.spark.createDataFrame([
+ (0, "a"),
+ (1, "d"),
+ (2, None)], ["id", "label"])
+
+ si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep",
+ stringOrderType="alphabetAsc")
+ model1 = si1.fit(df)
+ td1 = model1.transform(df)
+ actual1 = td1.select("id", "indexed").collect()
+ expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)]
+ self.assertEqual(actual1, expected1)
+
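+        # With handleInvalid="skip", the row with the null label is filtered out.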
+ si2 = si1.setHandleInvalid("skip")
+ model2 = si2.fit(df)
+ td2 = model2.transform(df)
+ actual2 = td2.select("id", "indexed").collect()
+ expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
+ self.assertEqual(actual2, expected2)
+
+ def test_string_indexer_from_labels(self):
+ model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label",
+ outputCol="indexed", handleInvalid="keep")
+ self.assertEqual(model.labels, ["a", "b", "c"])
+
+ df1 = self.spark.createDataFrame([
+ (0, "a"),
+ (1, "c"),
+ (2, None),
+ (3, "b"),
+ (4, "b")], ["id", "label"])
+
+ result1 = model.transform(df1)
+ actual1 = result1.select("id", "indexed").collect()
+ expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, indexed=3.0),
+ Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)]
+ self.assertEqual(actual1, expected1)
+
+ model_empty_labels = StringIndexerModel.from_labels(
+ [], inputCol="label", outputCol="indexed", handleInvalid="keep")
+ actual2 = model_empty_labels.transform(df1).select("id", "indexed").collect()
+ expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, indexed=0.0),
+ Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)]
+ self.assertEqual(actual2, expected2)
+
+        # Test that a model with default settings can transform
+ model_default = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label")
+ df2 = self.spark.createDataFrame([
+ (0, "a"),
+ (1, "c"),
+ (2, "b"),
+ (3, "b"),
+ (4, "b")], ["id", "label"])
+ transformed_list = model_default.transform(df2) \
+ .select(model_default.getOrDefault(model_default.outputCol)).collect()
+ self.assertEqual(len(transformed_list), 5)
+
+ def test_vector_size_hint(self):
+ df = self.spark.createDataFrame(
+ [(0, Vectors.dense([0.0, 10.0, 0.5])),
+ (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])),
+ (2, Vectors.dense([2.0, 12.0]))],
+ ["id", "vector"])
+
+ sizeHint = VectorSizeHint(
+ inputCol="vector",
+ handleInvalid="skip")
+ sizeHint.setSize(3)
+ self.assertEqual(sizeHint.getSize(), 3)
+
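+        # handleInvalid="skip" drops rows whose vector length differs from the declared size of 3.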
+ output = sizeHint.transform(df).head().vector
+ expected = DenseVector([0.0, 10.0, 0.5])
+ self.assertEqual(output, expected)
+
+
+class HashingTFTest(SparkSessionTestCase):
+
+ def test_apply_binary_term_freqs(self):
+
+ df = self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])
+ n = 10
+ hashingTF = HashingTF()
+ hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True)
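+        # binary=True caps each hashed term's frequency at 1.0, however often the term occurs.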
+ output = hashingTF.transform(df)
+ features = output.select("features").first().features.toArray()
+ expected = Vectors.dense([1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).toArray()
+ for i in range(0, n):
+ self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) +
+ ": expected " + str(expected[i]) + ", got " + str(features[i]))
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_feature import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py
new file mode 100644
index 0000000000000..4c280a4a67894
--- /dev/null
+++ b/python/pyspark/ml/tests/test_image.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import py4j
+
+from pyspark.ml.image import ImageSchema
+from pyspark.testing.mlutils import PySparkTestCase, SparkSessionTestCase
+from pyspark.sql import HiveContext, Row
+from pyspark.testing.utils import QuietTest
+
+
+class ImageReaderTest(SparkSessionTestCase):
+
+ def test_read_images(self):
+ data_path = 'data/mllib/images/origin/kittens'
+ df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
+ self.assertEqual(df.count(), 4)
+ first_row = df.take(1)[0][0]
+ array = ImageSchema.toNDArray(first_row)
+ self.assertEqual(len(array), first_row[1])
+ self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row)
+ self.assertEqual(df.schema, ImageSchema.imageSchema)
+ self.assertEqual(df.schema["image"].dataType, ImageSchema.columnSchema)
+ expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24}
+ self.assertEqual(ImageSchema.ocvTypes, expected)
+ expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data']
+ self.assertEqual(ImageSchema.imageFields, expected)
+ self.assertEqual(ImageSchema.undefinedImageType, "Undefined")
+
+ with QuietTest(self.sc):
+ self.assertRaisesRegexp(
+ TypeError,
+ "image argument should be pyspark.sql.types.Row; however",
+ lambda: ImageSchema.toNDArray("a"))
+
+ with QuietTest(self.sc):
+ self.assertRaisesRegexp(
+ ValueError,
+ "image argument should have attributes specified in",
+ lambda: ImageSchema.toNDArray(Row(a=1)))
+
+ with QuietTest(self.sc):
+ self.assertRaisesRegexp(
+ TypeError,
+ "array argument should be numpy.ndarray; however, it got",
+ lambda: ImageSchema.toImage("a"))
+
+
+class ImageReaderTest2(PySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(ImageReaderTest2, cls).setUpClass()
+ cls.hive_available = True
+        # Note that we enable Hive support here.
+ cls.spark = None
+ try:
+ cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+ except py4j.protocol.Py4JError:
+ cls.tearDownClass()
+ cls.hive_available = False
+ except TypeError:
+ cls.tearDownClass()
+ cls.hive_available = False
+ if cls.hive_available:
+ cls.spark = HiveContext._createForTesting(cls.sc)
+
+ def setUp(self):
+ if not self.hive_available:
+ self.skipTest("Hive is not available.")
+
+ @classmethod
+ def tearDownClass(cls):
+ super(ImageReaderTest2, cls).tearDownClass()
+ if cls.spark is not None:
+ cls.spark.sparkSession.stop()
+ cls.spark = None
+
+ def test_read_images_multiple_times(self):
+        # This test case checks that `ImageSchema.readImages` does not try to
+        # initialize the Hive client multiple times. See SPARK-22651.
+ data_path = 'data/mllib/images/origin/kittens'
+ ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
+ ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_image import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py
new file mode 100644
index 0000000000000..71cad5d7f5ad7
--- /dev/null
+++ b/python/pyspark/ml/tests/test_linalg.py
@@ -0,0 +1,384 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import array as pyarray
+
+from numpy import arange, array, array_equal, inf, ones, tile, zeros
+
+from pyspark.ml.linalg import DenseMatrix, DenseVector, MatrixUDT, SparseMatrix, SparseVector, \
+ Vector, VectorUDT, Vectors
+from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
+from pyspark.sql import Row
+
+
+ser = make_serializer()
+
+
+def _squared_distance(a, b):
+ if isinstance(a, Vector):
+ return a.squared_distance(b)
+ else:
+ return b.squared_distance(a)
+
+
+class VectorTests(MLlibTestCase):
+
+ def _test_serialize(self, v):
+ self.assertEqual(v, ser.loads(ser.dumps(v)))
+ jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v)))
+ nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec)))
+ self.assertEqual(v, nv)
+ vs = [v] * 100
+ jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs)))
+ nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs)))
+ self.assertEqual(vs, nvs)
+
+ def test_serialize(self):
+ self._test_serialize(DenseVector(range(10)))
+ self._test_serialize(DenseVector(array([1., 2., 3., 4.])))
+ self._test_serialize(DenseVector(pyarray.array('d', range(10))))
+ self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
+ self._test_serialize(SparseVector(3, {}))
+ self._test_serialize(DenseMatrix(2, 3, range(6)))
+ sm1 = SparseMatrix(
+ 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
+ self._test_serialize(sm1)
+
+ def test_dot(self):
+ sv = SparseVector(4, {1: 1, 3: 2})
+ dv = DenseVector(array([1., 2., 3., 4.]))
+ lst = DenseVector([1, 2, 3, 4])
+ mat = array([[1., 2., 3., 4.],
+ [1., 2., 3., 4.],
+ [1., 2., 3., 4.],
+ [1., 2., 3., 4.]])
+ arr = pyarray.array('d', [0, 1, 2, 3])
+ self.assertEqual(10.0, sv.dot(dv))
+ self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat)))
+ self.assertEqual(30.0, dv.dot(dv))
+ self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat)))
+ self.assertEqual(30.0, lst.dot(dv))
+ self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat)))
+ self.assertEqual(7.0, sv.dot(arr))
+
+ def test_squared_distance(self):
+ sv = SparseVector(4, {1: 1, 3: 2})
+ dv = DenseVector(array([1., 2., 3., 4.]))
+ lst = DenseVector([4, 3, 2, 1])
+ lst1 = [4, 3, 2, 1]
+ arr = pyarray.array('d', [0, 2, 1, 3])
+ narr = array([0, 2, 1, 3])
+ self.assertEqual(15.0, _squared_distance(sv, dv))
+ self.assertEqual(25.0, _squared_distance(sv, lst))
+ self.assertEqual(20.0, _squared_distance(dv, lst))
+ self.assertEqual(15.0, _squared_distance(dv, sv))
+ self.assertEqual(25.0, _squared_distance(lst, sv))
+ self.assertEqual(20.0, _squared_distance(lst, dv))
+ self.assertEqual(0.0, _squared_distance(sv, sv))
+ self.assertEqual(0.0, _squared_distance(dv, dv))
+ self.assertEqual(0.0, _squared_distance(lst, lst))
+ self.assertEqual(25.0, _squared_distance(sv, lst1))
+ self.assertEqual(3.0, _squared_distance(sv, arr))
+ self.assertEqual(3.0, _squared_distance(sv, narr))
+
+ def test_hash(self):
+ v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
+ v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v4 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+ self.assertEqual(hash(v1), hash(v2))
+ self.assertEqual(hash(v1), hash(v3))
+ self.assertEqual(hash(v2), hash(v3))
+ self.assertFalse(hash(v1) == hash(v4))
+ self.assertFalse(hash(v2) == hash(v4))
+
+ def test_eq(self):
+ v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
+ v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
+ v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
+ v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+ self.assertEqual(v1, v2)
+ self.assertEqual(v1, v3)
+ self.assertFalse(v2 == v4)
+ self.assertFalse(v1 == v5)
+ self.assertFalse(v1 == v6)
+
+ def test_equals(self):
+ indices = [1, 2, 4]
+ values = [1., 3., 2.]
+ self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.]))
+
+ def test_conversion(self):
+ # numpy arrays should be automatically upcast to float64
+ # tests for fix of [SPARK-5089]
+ v = array([1, 2, 3, 4], dtype='float64')
+ dv = DenseVector(v)
+ self.assertTrue(dv.array.dtype == 'float64')
+ v = array([1, 2, 3, 4], dtype='float32')
+ dv = DenseVector(v)
+ self.assertTrue(dv.array.dtype == 'float64')
+
+ def test_sparse_vector_indexing(self):
+ sv = SparseVector(5, {1: 1, 3: 2})
+ self.assertEqual(sv[0], 0.)
+ self.assertEqual(sv[3], 2.)
+ self.assertEqual(sv[1], 1.)
+ self.assertEqual(sv[2], 0.)
+ self.assertEqual(sv[4], 0.)
+ self.assertEqual(sv[-1], 0.)
+ self.assertEqual(sv[-2], 2.)
+ self.assertEqual(sv[-3], 0.)
+ self.assertEqual(sv[-5], 0.)
+ for ind in [5, -6]:
+ self.assertRaises(IndexError, sv.__getitem__, ind)
+ for ind in [7.8, '1']:
+ self.assertRaises(TypeError, sv.__getitem__, ind)
+
+ zeros = SparseVector(4, {})
+ self.assertEqual(zeros[0], 0.0)
+ self.assertEqual(zeros[3], 0.0)
+ for ind in [4, -5]:
+ self.assertRaises(IndexError, zeros.__getitem__, ind)
+
+ empty = SparseVector(0, {})
+ for ind in [-1, 0, 1]:
+ self.assertRaises(IndexError, empty.__getitem__, ind)
+
+ def test_sparse_vector_iteration(self):
+ self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0])
+ self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0])
+
+ def test_matrix_indexing(self):
+ mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
+ expected = [[0, 6], [1, 8], [4, 10]]
+ for i in range(3):
+ for j in range(2):
+ self.assertEqual(mat[i, j], expected[i][j])
+
+ for i, j in [(-1, 0), (4, 1), (3, 4)]:
+ self.assertRaises(IndexError, mat.__getitem__, (i, j))
+
+ def test_repr_dense_matrix(self):
+ mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
+ self.assertTrue(
+ repr(mat),
+ 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
+
+ mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True)
+ self.assertTrue(
+ repr(mat),
+ 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
+
+ mat = DenseMatrix(6, 3, zeros(18))
+ self.assertTrue(
+ repr(mat),
+ 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)')
+
+ def test_repr_sparse_matrix(self):
+ sm1t = SparseMatrix(
+ 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
+ isTransposed=True)
+ self.assertTrue(
+ repr(sm1t),
+ 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)')
+
+ indices = tile(arange(6), 3)
+ values = ones(18)
+ sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values)
+ self.assertTrue(
+ repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \
+ [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \
+ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \
+ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)")
+
+ self.assertTrue(
+ str(sm),
+ "6 X 3 CSCMatrix\n\
+ (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\
+ (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\
+ (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..")
+
+ sm = SparseMatrix(1, 18, zeros(19), [], [])
+ self.assertTrue(
+ repr(sm),
+ 'SparseMatrix(1, 18, \
+ [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)')
+
+ def test_sparse_matrix(self):
+ # Test sparse matrix creation.
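+        # CSC layout: colPtrs [0, 2, 2, 4, 4] places entries in columns 0 and 2 only.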
+ sm1 = SparseMatrix(
+ 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
+ self.assertEqual(sm1.numRows, 3)
+ self.assertEqual(sm1.numCols, 4)
+ self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4])
+ self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2])
+ self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0])
+ self.assertTrue(
+ repr(sm1),
+ 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)')
+
+ # Test indexing
+ expected = [
+ [0, 0, 0, 0],
+ [1, 0, 4, 0],
+ [2, 0, 5, 0]]
+
+ for i in range(3):
+ for j in range(4):
+ self.assertEqual(expected[i][j], sm1[i, j])
+ self.assertTrue(array_equal(sm1.toArray(), expected))
+
+ for i, j in [(-1, 1), (4, 3), (3, 5)]:
+ self.assertRaises(IndexError, sm1.__getitem__, (i, j))
+
+ # Test conversion to dense and sparse.
+ smnew = sm1.toDense().toSparse()
+ self.assertEqual(sm1.numRows, smnew.numRows)
+ self.assertEqual(sm1.numCols, smnew.numCols)
+ self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs))
+ self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices))
+ self.assertTrue(array_equal(sm1.values, smnew.values))
+
+ sm1t = SparseMatrix(
+ 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
+ isTransposed=True)
+ self.assertEqual(sm1t.numRows, 3)
+ self.assertEqual(sm1t.numCols, 4)
+ self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5])
+ self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2])
+ self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0])
+
+ expected = [
+ [3, 2, 0, 0],
+ [0, 0, 4, 0],
+ [9, 0, 8, 0]]
+
+ for i in range(3):
+ for j in range(4):
+ self.assertEqual(expected[i][j], sm1t[i, j])
+ self.assertTrue(array_equal(sm1t.toArray(), expected))
+
+ def test_dense_matrix_is_transposed(self):
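+        # mat1 lists its values row by row (isTransposed=True); mat lists the same matrix by column.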
+ mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True)
+ mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9])
+ self.assertEqual(mat1, mat)
+
+ expected = [[0, 4], [1, 6], [3, 9]]
+ for i in range(3):
+ for j in range(2):
+ self.assertEqual(mat1[i, j], expected[i][j])
+ self.assertTrue(array_equal(mat1.toArray(), expected))
+
+ sm = mat1.toSparse()
+ self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2]))
+ self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
+ self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
+
+ def test_norms(self):
+ a = DenseVector([0, 2, 3, -1])
+ self.assertAlmostEqual(a.norm(2), 3.742, 3)
+        self.assertEqual(a.norm(1), 6)
+        self.assertEqual(a.norm(inf), 3)
+ a = SparseVector(4, [0, 2], [3, -4])
+ self.assertAlmostEqual(a.norm(2), 5)
+        self.assertEqual(a.norm(1), 7)
+        self.assertEqual(a.norm(inf), 4)
+
+ tmp = SparseVector(4, [0, 2], [3, 0])
+ self.assertEqual(tmp.numNonzeros(), 1)
+
+
+class VectorUDTTests(MLlibTestCase):
+
+ dv0 = DenseVector([])
+ dv1 = DenseVector([1.0, 2.0])
+ sv0 = SparseVector(2, [], [])
+ sv1 = SparseVector(2, [1], [2.0])
+ udt = VectorUDT()
+
+ def test_json_schema(self):
+ self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+ def test_serialization(self):
+ for v in [self.dv0, self.dv1, self.sv0, self.sv1]:
+ self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
+
+ def test_infer_schema(self):
+ rdd = self.sc.parallelize([Row(label=1.0, features=self.dv1),
+ Row(label=0.0, features=self.sv1)])
+ df = rdd.toDF()
+ schema = df.schema
+ field = [f for f in schema.fields if f.name == "features"][0]
+ self.assertEqual(field.dataType, self.udt)
+ vectors = df.rdd.map(lambda p: p.features).collect()
+ self.assertEqual(len(vectors), 2)
+ for v in vectors:
+ if isinstance(v, SparseVector):
+ self.assertEqual(v, self.sv1)
+ elif isinstance(v, DenseVector):
+ self.assertEqual(v, self.dv1)
+ else:
+ raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
+
+
+class MatrixUDTTests(MLlibTestCase):
+
+ dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
+ dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
+ sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
+ sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
+ udt = MatrixUDT()
+
+ def test_json_schema(self):
+ self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+ def test_serialization(self):
+ for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
+ self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
+
+ def test_infer_schema(self):
+ rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
+ df = rdd.toDF()
+ schema = df.schema
+        self.assertEqual(schema.fields[1].dataType, self.udt)
+ matrices = df.rdd.map(lambda x: x._2).collect()
+ self.assertEqual(len(matrices), 2)
+ for m in matrices:
+ if isinstance(m, DenseMatrix):
+ self.assertTrue(m, self.dm1)
+ elif isinstance(m, SparseMatrix):
+ self.assertTrue(m, self.sm1)
+ else:
+ raise ValueError("Expected a matrix but got type %r" % type(m))
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_linalg import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py
new file mode 100644
index 0000000000000..17c1b0bf65dde
--- /dev/null
+++ b/python/pyspark/ml/tests/test_param.py
@@ -0,0 +1,366 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import inspect
+import sys
+import array as pyarray
+import unittest
+
+import numpy as np
+
+from pyspark import keyword_only
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.clustering import KMeans
+from pyspark.ml.feature import Binarizer, Bucketizer, ElementwiseProduct, IndexToString, \
+ VectorSlicer, Word2Vec
+from pyspark.ml.linalg import DenseVector, SparseVector
+from pyspark.ml.param import Param, Params, TypeConverters
+from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed
+from pyspark.ml.wrapper import JavaParams
+from pyspark.testing.mlutils import check_params, PySparkTestCase, SparkSessionTestCase
+
+
+if sys.version > '3':
+ xrange = range
+
+
+class ParamTypeConversionTests(PySparkTestCase):
+ """
+ Test that param type conversion happens.
+ """
+
+ def test_int(self):
+ lr = LogisticRegression(maxIter=5.0)
+ self.assertEqual(lr.getMaxIter(), 5)
+ self.assertTrue(type(lr.getMaxIter()) == int)
+ self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt"))
+ self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1))
+
+ def test_float(self):
+ lr = LogisticRegression(tol=1)
+ self.assertEqual(lr.getTol(), 1.0)
+ self.assertTrue(type(lr.getTol()) == float)
+ self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat"))
+
+ def test_vector(self):
+ ewp = ElementwiseProduct(scalingVec=[1, 3])
+ self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0]))
+ ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4]))
+ self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4]))
+ self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"]))
+
+ def test_list(self):
+ l = [0, 1]
+ for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), range(len(l)), l),
+ pyarray.array('l', l), xrange(2), tuple(l)]:
+ converted = TypeConverters.toList(lst_like)
+ self.assertEqual(type(converted), list)
+ self.assertListEqual(converted, l)
+
+ def test_list_int(self):
+ for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]),
+ SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0),
+ pyarray.array('d', [1.0, 2.0])]:
+ vs = VectorSlicer(indices=indices)
+ self.assertListEqual(vs.getIndices(), [1, 2])
+ self.assertTrue(all([type(v) == int for v in vs.getIndices()]))
+ self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"]))
+
+ def test_list_float(self):
+ b = Bucketizer(splits=[1, 4])
+ self.assertEqual(b.getSplits(), [1.0, 4.0])
+ self.assertTrue(all([type(v) == float for v in b.getSplits()]))
+ self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0]))
+
+ def test_list_string(self):
+ for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]:
+ idx_to_string = IndexToString(labels=labels)
+ self.assertListEqual(idx_to_string.getLabels(), ['a', 'b'])
+ self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2]))
+
+ def test_string(self):
+ lr = LogisticRegression()
+ for col in ['features', u'features', np.str_('features')]:
+ lr.setFeaturesCol(col)
+ self.assertEqual(lr.getFeaturesCol(), 'features')
+ self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3))
+
+ def test_bool(self):
+ self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1))
+ self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false"))
+
+
+class TestParams(HasMaxIter, HasInputCol, HasSeed):
+ """
+ A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
+ """
+ @keyword_only
+ def __init__(self, seed=None):
+ super(TestParams, self).__init__()
+ self._setDefault(maxIter=10)
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, seed=None):
+ """
+ setParams(self, seed=None)
+ Sets params for this test.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+
+class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
+ """
+ A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
+ """
+ @keyword_only
+ def __init__(self, seed=None):
+ super(OtherTestParams, self).__init__()
+ self._setDefault(maxIter=10)
+ kwargs = self._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, seed=None):
+ """
+ setParams(self, seed=None)
+ Sets params for this test.
+ """
+ kwargs = self._input_kwargs
+ return self._set(**kwargs)
+
+
+class HasThrowableProperty(Params):
+
+ def __init__(self):
+ super(HasThrowableProperty, self).__init__()
+ self.p = Param(self, "none", "empty param")
+
+ @property
+ def test_property(self):
+ raise RuntimeError("Test property to raise error when invoked")
+
+
+class ParamTests(SparkSessionTestCase):
+
+ def test_copy_new_parent(self):
+ testParams = TestParams()
+ # Copying an instantiated param should fail
+ with self.assertRaises(ValueError):
+ testParams.maxIter._copy_new_parent(testParams)
+ # Copying a dummy param should succeed
+ TestParams.maxIter._copy_new_parent(testParams)
+ maxIter = testParams.maxIter
+ self.assertEqual(maxIter.name, "maxIter")
+ self.assertEqual(maxIter.doc, "max number of iterations (>= 0).")
+ self.assertTrue(maxIter.parent == testParams.uid)
+
+ def test_param(self):
+ testParams = TestParams()
+ maxIter = testParams.maxIter
+ self.assertEqual(maxIter.name, "maxIter")
+ self.assertEqual(maxIter.doc, "max number of iterations (>= 0).")
+ self.assertTrue(maxIter.parent == testParams.uid)
+
+ def test_hasparam(self):
+ testParams = TestParams()
+ self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
+ self.assertFalse(testParams.hasParam("notAParameter"))
+ self.assertTrue(testParams.hasParam(u"maxIter"))
+
+ def test_resolveparam(self):
+ testParams = TestParams()
+ self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter)
+ self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter)
+
+ self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter)
+ if sys.version_info[0] >= 3:
+ # In Python 3, it is allowed to get/set attributes with non-ascii characters.
+ e_cls = AttributeError
+ else:
+ e_cls = UnicodeEncodeError
+ self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아"))
+
+ def test_params(self):
+ testParams = TestParams()
+ maxIter = testParams.maxIter
+ inputCol = testParams.inputCol
+ seed = testParams.seed
+
+ params = testParams.params
+ self.assertEqual(params, [inputCol, maxIter, seed])
+
+ self.assertTrue(testParams.hasParam(maxIter.name))
+ self.assertTrue(testParams.hasDefault(maxIter))
+ self.assertFalse(testParams.isSet(maxIter))
+ self.assertTrue(testParams.isDefined(maxIter))
+ self.assertEqual(testParams.getMaxIter(), 10)
+ testParams.setMaxIter(100)
+ self.assertTrue(testParams.isSet(maxIter))
+ self.assertEqual(testParams.getMaxIter(), 100)
+
+ self.assertTrue(testParams.hasParam(inputCol.name))
+ self.assertFalse(testParams.hasDefault(inputCol))
+ self.assertFalse(testParams.isSet(inputCol))
+ self.assertFalse(testParams.isDefined(inputCol))
+ with self.assertRaises(KeyError):
+ testParams.getInputCol()
+
+ otherParam = Param(Params._dummy(), "otherParam", "Parameter used to test that " +
+ "set raises an error for a non-member parameter.",
+ typeConverter=TypeConverters.toString)
+ with self.assertRaises(ValueError):
+ testParams.set(otherParam, "value")
+
+ # Since the default is normally random, set it to a known number for debug str
+ testParams._setDefault(seed=41)
+ testParams.setSeed(43)
+
+ self.assertEqual(
+ testParams.explainParams(),
+ "\n".join(["inputCol: input column name. (undefined)",
+ "maxIter: max number of iterations (>= 0). (default: 10, current: 100)",
+ "seed: random seed. (default: 41, current: 43)"]))
+
+ def test_kmeans_param(self):
+ algo = KMeans()
+ self.assertEqual(algo.getInitMode(), "k-means||")
+ algo.setK(10)
+ self.assertEqual(algo.getK(), 10)
+ algo.setInitSteps(10)
+ self.assertEqual(algo.getInitSteps(), 10)
+ self.assertEqual(algo.getDistanceMeasure(), "euclidean")
+ algo.setDistanceMeasure("cosine")
+ self.assertEqual(algo.getDistanceMeasure(), "cosine")
+
+ def test_hasseed(self):
+ noSeedSpecd = TestParams()
+ withSeedSpecd = TestParams(seed=42)
+ other = OtherTestParams()
+ # Check that we no longer use 42 as the magic number
+ self.assertNotEqual(noSeedSpecd.getSeed(), 42)
+ origSeed = noSeedSpecd.getSeed()
+ # Check that we only compute the seed once
+ self.assertEqual(noSeedSpecd.getSeed(), origSeed)
+ # Check that a specified seed is honored
+ self.assertEqual(withSeedSpecd.getSeed(), 42)
+ # Check that a different class has a different seed
+ self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed())
+
+ def test_param_property_error(self):
+ param_store = HasThrowableProperty()
+ self.assertRaises(RuntimeError, lambda: param_store.test_property)
+ params = param_store.params # should not invoke the property 'test_property'
+ self.assertEqual(len(params), 1)
+
+ def test_word2vec_param(self):
+ model = Word2Vec().setWindowSize(6)
+ # Check windowSize is set properly
+ self.assertEqual(model.getWindowSize(), 6)
+
+ def test_copy_param_extras(self):
+ tp = TestParams(seed=42)
+ extra = {tp.getParam(TestParams.inputCol.name): "copy_input"}
+ tp_copy = tp.copy(extra=extra)
+ self.assertEqual(tp.uid, tp_copy.uid)
+ self.assertEqual(tp.params, tp_copy.params)
+ for k, v in extra.items():
+ self.assertTrue(tp_copy.isDefined(k))
+ self.assertEqual(tp_copy.getOrDefault(k), v)
+ copied_no_extra = {}
+ for k, v in tp_copy._paramMap.items():
+ if k not in extra:
+ copied_no_extra[k] = v
+ self.assertEqual(tp._paramMap, copied_no_extra)
+ self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap)
+
+ def test_logistic_regression_check_thresholds(self):
+ self.assertIsInstance(
+ LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
+ LogisticRegression
+ )
+
+ self.assertRaisesRegexp(
+ ValueError,
+ "Logistic Regression getThreshold found inconsistent.*$",
+ LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
+ )
+
+ def test_preserve_set_state(self):
+ dataset = self.spark.createDataFrame([(0.5,)], ["data"])
+ binarizer = Binarizer(inputCol="data")
+ self.assertFalse(binarizer.isSet("threshold"))
+ binarizer.transform(dataset)
+ binarizer._transfer_params_from_java()
+ self.assertFalse(binarizer.isSet("threshold"),
+ "Params not explicitly set should remain unset after transform")
+
+ def test_default_params_transferred(self):
+ dataset = self.spark.createDataFrame([(0.5,)], ["data"])
+ binarizer = Binarizer(inputCol="data")
+ # intentionally change the pyspark default, but don't set it
+ binarizer._defaultParamMap[binarizer.outputCol] = "my_default"
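+        # The changed default must still reach the Java side, hence the output column "my_default".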
+ result = binarizer.transform(dataset).select("my_default").collect()
+ self.assertFalse(binarizer.isSet(binarizer.outputCol))
+ self.assertEqual(result[0][0], 1.0)
+
+
+class DefaultValuesTests(PySparkTestCase):
+ """
+ Test :py:class:`JavaParams` classes to see if their default Param values match
+ those in their Scala counterparts.
+ """
+
+ def test_java_params(self):
+ import pyspark.ml.feature
+ import pyspark.ml.classification
+ import pyspark.ml.clustering
+ import pyspark.ml.evaluation
+ import pyspark.ml.pipeline
+ import pyspark.ml.recommendation
+ import pyspark.ml.regression
+
+ modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering,
+ pyspark.ml.evaluation, pyspark.ml.pipeline, pyspark.ml.recommendation,
+ pyspark.ml.regression]
+ for module in modules:
+ for name, cls in inspect.getmembers(module, inspect.isclass):
+ if not name.endswith('Model') and not name.endswith('Params') \
+ and issubclass(cls, JavaParams) and not inspect.isabstract(cls):
+ # NOTE: disable check_params_exist until there is parity with Scala API
+ check_params(self, cls(), check_params_exist=False)
+
+ # Additional classes that need explicit construction
+ from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel
+ check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'),
+ check_params_exist=False)
+ check_params(self, StringIndexerModel.from_labels(['a', 'b'], 'input'),
+ check_params_exist=False)
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_param import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py
new file mode 100644
index 0000000000000..34d687039ab34
--- /dev/null
+++ b/python/pyspark/ml/tests/test_persistence.py
@@ -0,0 +1,361 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+from shutil import rmtree
+import tempfile
+import unittest
+
+from pyspark.ml import Transformer
+from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \
+ OneVsRestModel
+from pyspark.ml.feature import Binarizer, HashingTF, PCA
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.param import Params
+from pyspark.ml.pipeline import Pipeline, PipelineModel
+from pyspark.ml.regression import DecisionTreeRegressor, LinearRegression
+from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWriter
+from pyspark.ml.wrapper import JavaParams
+from pyspark.testing.mlutils import MockUnaryTransformer, SparkSessionTestCase
+
+
+class PersistenceTest(SparkSessionTestCase):
+
+ def test_linear_regression(self):
+ lr = LinearRegression(maxIter=1)
+ path = tempfile.mkdtemp()
+ lr_path = path + "/lr"
+ lr.save(lr_path)
+ lr2 = LinearRegression.load(lr_path)
+ self.assertEqual(lr.uid, lr2.uid)
+ self.assertEqual(type(lr.uid), type(lr2.uid))
+ self.assertEqual(lr2.uid, lr2.maxIter.parent,
+ "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
+ % (lr2.uid, lr2.maxIter.parent))
+ self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+ "Loaded LinearRegression instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+ def test_linear_regression_pmml_basic(self):
+        # Most of the validation is done on the Scala side; here we just check
+        # that we output text rather than parquet (i.e. that the format flag
+        # was respected).
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ lr = LinearRegression(maxIter=1)
+ model = lr.fit(df)
+ path = tempfile.mkdtemp()
+ lr_path = path + "/lr-pmml"
+ model.write().format("pmml").save(lr_path)
+ pmml_text_list = self.sc.textFile(lr_path).collect()
+ pmml_text = "\n".join(pmml_text_list)
+ self.assertIn("Apache Spark", pmml_text)
+ self.assertIn("PMML", pmml_text)
+
+ def test_logistic_regression(self):
+ lr = LogisticRegression(maxIter=1)
+ path = tempfile.mkdtemp()
+ lr_path = path + "/logreg"
+ lr.save(lr_path)
+ lr2 = LogisticRegression.load(lr_path)
+ self.assertEqual(lr2.uid, lr2.maxIter.parent,
+ "Loaded LogisticRegression instance uid (%s) "
+ "did not match Param's uid (%s)"
+ % (lr2.uid, lr2.maxIter.parent))
+ self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+ "Loaded LogisticRegression instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+ def _compare_params(self, m1, m2, param):
+ """
+ Compare 2 ML Params instances for the given param, and assert both have the same param value
+ and parent. The param must be a parameter of m1.
+ """
+        # Prevent a key-not-found error when a param is in neither paramMap nor defaultParamMap.
+ if m1.isDefined(param):
+ paramValue1 = m1.getOrDefault(param)
+ paramValue2 = m2.getOrDefault(m2.getParam(param.name))
+ if isinstance(paramValue1, Params):
+ self._compare_pipelines(paramValue1, paramValue2)
+ else:
+ self.assertEqual(paramValue1, paramValue2) # for general types param
+ # Assert parents are equal
+ self.assertEqual(param.parent, m2.getParam(param.name).parent)
+ else:
+            # If the param is not defined on m1, it should not be defined on m2
+            # either. See SPARK-14931.
+ self.assertFalse(m2.isDefined(m2.getParam(param.name)))
+
+ def _compare_pipelines(self, m1, m2):
+ """
+ Compare 2 ML types, asserting that they are equivalent.
+ This currently supports:
+ - basic types
+ - Pipeline, PipelineModel
+ - OneVsRest, OneVsRestModel
+ This checks:
+ - uid
+ - type
+ - Param values and parents
+ """
+ self.assertEqual(m1.uid, m2.uid)
+ self.assertEqual(type(m1), type(m2))
+ if isinstance(m1, JavaParams) or isinstance(m1, Transformer):
+ self.assertEqual(len(m1.params), len(m2.params))
+ for p in m1.params:
+ self._compare_params(m1, m2, p)
+ elif isinstance(m1, Pipeline):
+ self.assertEqual(len(m1.getStages()), len(m2.getStages()))
+ for s1, s2 in zip(m1.getStages(), m2.getStages()):
+ self._compare_pipelines(s1, s2)
+ elif isinstance(m1, PipelineModel):
+ self.assertEqual(len(m1.stages), len(m2.stages))
+ for s1, s2 in zip(m1.stages, m2.stages):
+ self._compare_pipelines(s1, s2)
+ elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel):
+ for p in m1.params:
+ self._compare_params(m1, m2, p)
+ if isinstance(m1, OneVsRestModel):
+ self.assertEqual(len(m1.models), len(m2.models))
+ for x, y in zip(m1.models, m2.models):
+ self._compare_pipelines(x, y)
+ else:
+ raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
+
+ def test_pipeline_persistence(self):
+ """
+ Pipeline[HashingTF, PCA]
+ """
+ temp_path = tempfile.mkdtemp()
+
+ try:
+ df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+ tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+ pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+ pl = Pipeline(stages=[tf, pca])
+ model = pl.fit(df)
+
+ pipeline_path = temp_path + "/pipeline"
+ pl.save(pipeline_path)
+ loaded_pipeline = Pipeline.load(pipeline_path)
+ self._compare_pipelines(pl, loaded_pipeline)
+
+ model_path = temp_path + "/pipeline-model"
+ model.save(model_path)
+ loaded_model = PipelineModel.load(model_path)
+ self._compare_pipelines(model, loaded_model)
+ finally:
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
+
+ def test_nested_pipeline_persistence(self):
+ """
+ Pipeline[HashingTF, Pipeline[PCA]]
+ """
+ temp_path = tempfile.mkdtemp()
+
+ try:
+ df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+ tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+ pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+ p0 = Pipeline(stages=[pca])
+ pl = Pipeline(stages=[tf, p0])
+ model = pl.fit(df)
+
+ pipeline_path = temp_path + "/pipeline"
+ pl.save(pipeline_path)
+ loaded_pipeline = Pipeline.load(pipeline_path)
+ self._compare_pipelines(pl, loaded_pipeline)
+
+ model_path = temp_path + "/pipeline-model"
+ model.save(model_path)
+ loaded_model = PipelineModel.load(model_path)
+ self._compare_pipelines(model, loaded_model)
+ finally:
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
+
+ def test_python_transformer_pipeline_persistence(self):
+ """
+ Pipeline[MockUnaryTransformer, Binarizer]
+ """
+ temp_path = tempfile.mkdtemp()
+
+ try:
+ df = self.spark.range(0, 10).toDF('input')
+ tf = MockUnaryTransformer(shiftVal=2)\
+ .setInputCol("input").setOutputCol("shiftedInput")
+ tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
+ pl = Pipeline(stages=[tf, tf2])
+ model = pl.fit(df)
+
+ pipeline_path = temp_path + "/pipeline"
+ pl.save(pipeline_path)
+ loaded_pipeline = Pipeline.load(pipeline_path)
+ self._compare_pipelines(pl, loaded_pipeline)
+
+ model_path = temp_path + "/pipeline-model"
+ model.save(model_path)
+ loaded_model = PipelineModel.load(model_path)
+ self._compare_pipelines(model, loaded_model)
+ finally:
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
+
+ def test_onevsrest(self):
+ temp_path = tempfile.mkdtemp()
+ df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+ (1.0, Vectors.sparse(2, [], [])),
+ (2.0, Vectors.dense(0.5, 0.5))] * 10,
+ ["label", "features"])
+ lr = LogisticRegression(maxIter=5, regParam=0.01)
+ ovr = OneVsRest(classifier=lr)
+ model = ovr.fit(df)
+ ovrPath = temp_path + "/ovr"
+ ovr.save(ovrPath)
+ loadedOvr = OneVsRest.load(ovrPath)
+ self._compare_pipelines(ovr, loadedOvr)
+ modelPath = temp_path + "/ovrModel"
+ model.save(modelPath)
+ loadedModel = OneVsRestModel.load(modelPath)
+ self._compare_pipelines(model, loadedModel)
+
+ def test_decisiontree_classifier(self):
+ dt = DecisionTreeClassifier(maxDepth=1)
+ path = tempfile.mkdtemp()
+ dtc_path = path + "/dtc"
+ dt.save(dtc_path)
+ dt2 = DecisionTreeClassifier.load(dtc_path)
+ self.assertEqual(dt2.uid, dt2.maxDepth.parent,
+ "Loaded DecisionTreeClassifier instance uid (%s) "
+ "did not match Param's uid (%s)"
+ % (dt2.uid, dt2.maxDepth.parent))
+ self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
+ "Loaded DecisionTreeClassifier instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+ def test_decisiontree_regressor(self):
+ dt = DecisionTreeRegressor(maxDepth=1)
+ path = tempfile.mkdtemp()
+ dtr_path = path + "/dtr"
+ dt.save(dtr_path)
+        dt2 = DecisionTreeRegressor.load(dtr_path)
+ self.assertEqual(dt2.uid, dt2.maxDepth.parent,
+ "Loaded DecisionTreeRegressor instance uid (%s) "
+ "did not match Param's uid (%s)"
+ % (dt2.uid, dt2.maxDepth.parent))
+ self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
+ "Loaded DecisionTreeRegressor instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+ def test_default_read_write(self):
+ temp_path = tempfile.mkdtemp()
+
+ lr = LogisticRegression()
+ lr.setMaxIter(50)
+ lr.setThreshold(.75)
+ writer = DefaultParamsWriter(lr)
+
+ savePath = temp_path + "/lr"
+ writer.save(savePath)
+
+ reader = DefaultParamsReadable.read()
+ lr2 = reader.load(savePath)
+
+ self.assertEqual(lr.uid, lr2.uid)
+ self.assertEqual(lr.extractParamMap(), lr2.extractParamMap())
+
+ # test overwrite
+ lr.setThreshold(.8)
+ writer.overwrite().save(savePath)
+
+ reader = DefaultParamsReadable.read()
+ lr3 = reader.load(savePath)
+
+ self.assertEqual(lr.uid, lr3.uid)
+ self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
+
+ def test_default_read_write_default_params(self):
+ lr = LogisticRegression()
+ self.assertFalse(lr.isSet(lr.getParam("threshold")))
+
+ lr.setMaxIter(50)
+ lr.setThreshold(.75)
+
+ # `threshold` is set by user, default param `predictionCol` is not set by user.
+ self.assertTrue(lr.isSet(lr.getParam("threshold")))
+ self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+ self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+ writer = DefaultParamsWriter(lr)
+ metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
+ self.assertTrue("defaultParamMap" in metadata)
+
+ reader = DefaultParamsReadable.read()
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr)
+ reader.getAndSetParams(lr, loadedMetadata)
+
+ self.assertTrue(lr.isSet(lr.getParam("threshold")))
+ self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+ self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+ # manually create metadata without `defaultParamMap` section.
+ del metadata['defaultParamMap']
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr)
+ with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
+ reader.getAndSetParams(lr, loadedMetadata)
+
+        # Prior to 2.4.0, metadata did not have `defaultParamMap`.
+ metadata['sparkVersion'] = '2.3.0'
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+        loadedMetadata = reader._parseMetaData(metadataStr)
+ reader.getAndSetParams(lr, loadedMetadata)
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_persistence import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py
new file mode 100644
index 0000000000000..9e3e6c4a75d7a
--- /dev/null
+++ b/python/pyspark/ml/tests/test_pipeline.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.ml.pipeline import Pipeline
+from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer, PySparkTestCase
+
+
+class PipelineTests(PySparkTestCase):
+
+ def test_pipeline(self):
+ dataset = MockDataset()
+ estimator0 = MockEstimator()
+ transformer1 = MockTransformer()
+ estimator2 = MockEstimator()
+ transformer3 = MockTransformer()
+ pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3])
+ pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
+ model0, transformer1, model2, transformer3 = pipeline_model.stages
+ self.assertEqual(0, model0.dataset_index)
+ self.assertEqual(0, model0.getFake())
+ self.assertEqual(1, transformer1.dataset_index)
+ self.assertEqual(1, transformer1.getFake())
+ self.assertEqual(2, dataset.index)
+ self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.")
+ self.assertIsNone(transformer3.dataset_index,
+ "The last transformer shouldn't be called in fit.")
+ dataset = pipeline_model.transform(dataset)
+ self.assertEqual(2, model0.dataset_index)
+ self.assertEqual(3, transformer1.dataset_index)
+ self.assertEqual(4, model2.dataset_index)
+ self.assertEqual(5, transformer3.dataset_index)
+ self.assertEqual(6, dataset.index)
+
+ def test_identity_pipeline(self):
+ dataset = MockDataset()
+
+ def doTransform(pipeline):
+ pipeline_model = pipeline.fit(dataset)
+ return pipeline_model.transform(dataset)
+ # check that empty pipeline did not perform any transformation
+ self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index)
+ # check that fitting a Pipeline without setting the stages param raises KeyError
+ self.assertRaises(KeyError, lambda: doTransform(Pipeline()))
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_pipeline import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py
new file mode 100644
index 0000000000000..11aaf2e8083e1
--- /dev/null
+++ b/python/pyspark/ml/tests/test_stat.py
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.stat import ChiSquareTest
+from pyspark.sql import DataFrame
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class ChiSquareTestTests(SparkSessionTestCase):
+
+ def test_chisquaretest(self):
+ data = [[0, Vectors.dense([0, 1, 2])],
+ [1, Vectors.dense([1, 1, 1])],
+ [2, Vectors.dense([2, 1, 0])]]
+ df = self.spark.createDataFrame(data, ['label', 'feat'])
+ res = ChiSquareTest.test(df, 'feat', 'label')
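+ # ChiSquareTest.test returns a single-row DataFrame with per-feature pValues,
+ # degreesOfFreedom and statistics columns, checked against the schema below.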
+ # This line is hitting the collect bug described in #17218, commented for now.
+ # pValues = res.select("degreesOfFreedom").collect()
+ self.assertIsInstance(res, DataFrame)
+ fieldNames = set(field.name for field in res.schema.fields)
+ expectedFields = ["pValues", "degreesOfFreedom", "statistics"]
+ self.assertTrue(all(field in fieldNames for field in expectedFields))
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_stat import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py
new file mode 100644
index 0000000000000..8575111c84025
--- /dev/null
+++ b/python/pyspark/ml/tests/test_training_summary.py
@@ -0,0 +1,251 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import unittest
+
+if sys.version > '3':
+ basestring = str
+
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression
+from pyspark.sql import DataFrame
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class TrainingSummaryTest(SparkSessionTestCase):
+
+ def test_linear_regression_summary(self):
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight",
+ fitIntercept=False)
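+ # solver="normal" matters here: the statistical summaries asserted below
+ # (coefficientStandardErrors, tValues, pValues) are only available when the
+ # model is fit with the normal-equation solver.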
+ model = lr.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertGreater(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.featuresCol, "features")
+ objHist = s.objectiveHistory
+ self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
+ self.assertAlmostEqual(s.explainedVariance, 0.25, 2)
+ self.assertAlmostEqual(s.meanAbsoluteError, 0.0)
+ self.assertAlmostEqual(s.meanSquaredError, 0.0)
+ self.assertAlmostEqual(s.rootMeanSquaredError, 0.0)
+ self.assertAlmostEqual(s.r2, 1.0, 2)
+ self.assertAlmostEqual(s.r2adj, 1.0, 2)
+ self.assertTrue(isinstance(s.residuals, DataFrame))
+ self.assertEqual(s.numInstances, 2)
+ self.assertEqual(s.degreesOfFreedom, 1)
+ devResiduals = s.devianceResiduals
+ self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float))
+ coefStdErr = s.coefficientStandardErrors
+ self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float))
+ tValues = s.tValues
+ self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float))
+ pValues = s.pValues
+ self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float))
+ # test evaluation (with training dataset) produces a summary with same values
+ # one check is enough to verify a summary is returned
+ # The child class LinearRegressionTrainingSummary runs full test
+ sameSummary = model.evaluate(df)
+ self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance)
+
+ def test_glr_summary(self):
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight",
+ fitIntercept=False)
+ model = glr.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertEqual(s.numIterations, 1) # this should default to a single iteration of WLS
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertEqual(s.numInstances, 2)
+ self.assertTrue(isinstance(s.residuals(), DataFrame))
+ self.assertTrue(isinstance(s.residuals("pearson"), DataFrame))
+ coefStdErr = s.coefficientStandardErrors
+ self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float))
+ tValues = s.tValues
+ self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float))
+ pValues = s.pValues
+ self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float))
+ self.assertEqual(s.degreesOfFreedom, 1)
+ self.assertEqual(s.residualDegreeOfFreedom, 1)
+ self.assertEqual(s.residualDegreeOfFreedomNull, 2)
+ self.assertEqual(s.rank, 1)
+ self.assertTrue(isinstance(s.solver, basestring))
+ self.assertTrue(isinstance(s.aic, float))
+ self.assertTrue(isinstance(s.deviance, float))
+ self.assertTrue(isinstance(s.nullDeviance, float))
+ self.assertTrue(isinstance(s.dispersion, float))
+ # test evaluation (with training dataset) produces a summary with same values
+ # one check is enough to verify a summary is returned
+ # The child class GeneralizedLinearRegressionTrainingSummary runs full test
+ sameSummary = model.evaluate(df)
+ self.assertAlmostEqual(sameSummary.deviance, s.deviance)
+
+ def test_binary_logistic_regression_summary(self):
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
+ model = lr.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.probabilityCol, "probability")
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.featuresCol, "features")
+ self.assertEqual(s.predictionCol, "prediction")
+ objHist = s.objectiveHistory
+ self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
+ self.assertGreater(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.labels, list))
+ self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.precisionByLabel, list))
+ self.assertTrue(isinstance(s.recallByLabel, list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(), list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
+ self.assertTrue(isinstance(s.roc, DataFrame))
+ self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
+ self.assertTrue(isinstance(s.pr, DataFrame))
+ self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
+ self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
+ self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
+ self.assertAlmostEqual(s.accuracy, 1.0, 2)
+ self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2)
+ self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2)
+ self.assertAlmostEqual(s.weightedRecall, 1.0, 2)
+ self.assertAlmostEqual(s.weightedPrecision, 1.0, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2)
+ # test evaluation (with training dataset) produces a summary with same values
+ # one check is enough to verify a summary is returned, Scala version runs full test
+ sameSummary = model.evaluate(df)
+ self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
+
+ def test_multiclass_logistic_regression_summary(self):
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], [])),
+ (2.0, 2.0, Vectors.dense(2.0)),
+ (2.0, 2.0, Vectors.dense(1.9))],
+ ["label", "weight", "features"])
+ lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
+ model = lr.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.probabilityCol, "probability")
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.featuresCol, "features")
+ self.assertEqual(s.predictionCol, "prediction")
+ objHist = s.objectiveHistory
+ self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
+ self.assertGreater(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.labels, list))
+ self.assertTrue(isinstance(s.truePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.falsePositiveRateByLabel, list))
+ self.assertTrue(isinstance(s.precisionByLabel, list))
+ self.assertTrue(isinstance(s.recallByLabel, list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(), list))
+ self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list))
+ self.assertAlmostEqual(s.accuracy, 0.75, 2)
+ self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2)
+ self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2)
+ self.assertAlmostEqual(s.weightedRecall, 0.75, 2)
+ self.assertAlmostEqual(s.weightedPrecision, 0.583, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2)
+ self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2)
+ # test evaluation (with training dataset) produces a summary with same values
+ # one check is enough to verify a summary is returned, Scala version runs full test
+ sameSummary = model.evaluate(df)
+ self.assertAlmostEqual(sameSummary.accuracy, s.accuracy)
+
+ def test_gaussian_mixture_summary(self):
+ data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
+ (Vectors.sparse(1, [], []),)]
+ df = self.spark.createDataFrame(data, ["features"])
+ gmm = GaussianMixture(k=2)
+ model = gmm.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
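+ # The cluster sizes and iteration count below are deterministic for this tiny
+ # dataset with the default seed and tolerance.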
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.probabilityCol, "probability")
+ self.assertTrue(isinstance(s.probability, DataFrame))
+ self.assertEqual(s.featuresCol, "features")
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertTrue(isinstance(s.cluster, DataFrame))
+ self.assertEqual(len(s.clusterSizes), 2)
+ self.assertEqual(s.k, 2)
+ self.assertEqual(s.numIter, 3)
+
+ def test_bisecting_kmeans_summary(self):
+ data = [(Vectors.dense(1.0),), (Vectors.dense(5.0),), (Vectors.dense(10.0),),
+ (Vectors.sparse(1, [], []),)]
+ df = self.spark.createDataFrame(data, ["features"])
+ bkm = BisectingKMeans(k=2)
+ model = bkm.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.featuresCol, "features")
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertTrue(isinstance(s.cluster, DataFrame))
+ self.assertEqual(len(s.clusterSizes), 2)
+ self.assertEqual(s.k, 2)
+ self.assertEqual(s.numIter, 20)
+
+ def test_kmeans_summary(self):
+ data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
+ (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
+ df = self.spark.createDataFrame(data, ["features"])
+ kmeans = KMeans(k=2, seed=1)
+ model = kmeans.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.featuresCol, "features")
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertTrue(isinstance(s.cluster, DataFrame))
+ self.assertEqual(len(s.clusterSizes), 2)
+ self.assertEqual(s.k, 2)
+ self.assertEqual(s.numIter, 1)
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_training_summary import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
new file mode 100644
index 0000000000000..39bb921aaf43d
--- /dev/null
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -0,0 +1,544 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import unittest
+
+from pyspark.ml import Estimator, Model
+from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel, OneVsRest
+from pyspark.ml.evaluation import BinaryClassificationEvaluator, \
+ MulticlassClassificationEvaluator, RegressionEvaluator
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.param import Param, Params
+from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder, \
+ TrainValidationSplit, TrainValidationSplitModel
+from pyspark.sql.functions import rand
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
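+# The helpers below define a toy estimator whose model simply adds uniform noise,
+# scaled by the `inducedError` param, to the feature column; the tuning tests use
+# it to verify that model selection picks the candidate with zero induced error.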
+class HasInducedError(Params):
+
+ def __init__(self):
+ super(HasInducedError, self).__init__()
+ self.inducedError = Param(self, "inducedError",
+ "Uniformly-distributed error added to feature")
+
+ def getInducedError(self):
+ return self.getOrDefault(self.inducedError)
+
+
+class InducedErrorModel(Model, HasInducedError):
+
+ def __init__(self):
+ super(InducedErrorModel, self).__init__()
+
+ def _transform(self, dataset):
+ return dataset.withColumn("prediction",
+ dataset.feature + (rand(0) * self.getInducedError()))
+
+
+class InducedErrorEstimator(Estimator, HasInducedError):
+
+ def __init__(self, inducedError=1.0):
+ super(InducedErrorEstimator, self).__init__()
+ self._set(inducedError=inducedError)
+
+ def _fit(self, dataset):
+ model = InducedErrorModel()
+ self._copyValues(model)
+ return model
+
+
+class CrossValidatorTests(SparkSessionTestCase):
+
+ def test_copy(self):
+ dataset = self.spark.createDataFrame([
+ (10, 10.0),
+ (50, 50.0),
+ (100, 100.0),
+ (500, 500.0)] * 10,
+ ["feature", "label"])
+
+ iee = InducedErrorEstimator()
+ evaluator = RegressionEvaluator(metricName="rmse")
+
+ grid = (ParamGridBuilder()
+ .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
+ .build())
+ cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+ cvCopied = cv.copy()
+ self.assertEqual(cv.getEstimator().uid, cvCopied.getEstimator().uid)
+
+ cvModel = cv.fit(dataset)
+ cvModelCopied = cvModel.copy()
+ for index in range(len(cvModel.avgMetrics)):
+ self.assertTrue(abs(cvModel.avgMetrics[index] - cvModelCopied.avgMetrics[index])
+ < 0.0001)
+
+ def test_fit_minimize_metric(self):
+ dataset = self.spark.createDataFrame([
+ (10, 10.0),
+ (50, 50.0),
+ (100, 100.0),
+ (500, 500.0)] * 10,
+ ["feature", "label"])
+
+ iee = InducedErrorEstimator()
+ evaluator = RegressionEvaluator(metricName="rmse")
+
+ grid = (ParamGridBuilder()
+ .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
+ .build())
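+ # rmse is minimized, so the grid entry with zero induced error should win.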
+ cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ bestModel = cvModel.bestModel
+ bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+
+ self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
+ "Best model should have zero induced error")
+ self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
+
+ def test_fit_maximize_metric(self):
+ dataset = self.spark.createDataFrame([
+ (10, 10.0),
+ (50, 50.0),
+ (100, 100.0),
+ (500, 500.0)] * 10,
+ ["feature", "label"])
+
+ iee = InducedErrorEstimator()
+ evaluator = RegressionEvaluator(metricName="r2")
+
+ grid = (ParamGridBuilder()
+ .addGrid(iee.inducedError, [100.0, 0.0, 10000.0])
+ .build())
+ cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ bestModel = cvModel.bestModel
+ bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+
+ self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
+ "Best model should have zero induced error")
+ self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+
+ def test_param_grid_type_coercion(self):
+ lr = LogisticRegression(maxIter=10)
+ paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build()
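+ # regParam is a float param, so the integer 1 in the grid should be coerced to 1.0.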
+ for param in paramGrid:
+ for v in param.values():
+ assert(type(v) == float)
+
+ def test_save_load_trained_model(self):
+ # This tests saving and loading the trained model only.
+ # Save/load for CrossValidator will be added later: SPARK-13786
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ lrModel = cvModel.bestModel
+
+ cvModelPath = temp_path + "/cvModel"
+ lrModel.save(cvModelPath)
+ loadedLrModel = LogisticRegressionModel.load(cvModelPath)
+ self.assertEqual(loadedLrModel.uid, lrModel.uid)
+ self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
+
+ def test_save_load_simple_estimator(self):
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+
+ # test save/load of CrossValidator
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ cvPath = temp_path + "/cv"
+ cv.save(cvPath)
+ loadedCV = CrossValidator.load(cvPath)
+ self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+ self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+ self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+
+ # test save/load of CrossValidatorModel
+ cvModelPath = temp_path + "/cvModel"
+ cvModel.save(cvModelPath)
+ loadedModel = CrossValidatorModel.load(cvModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
+ def test_parallel_evaluation(self):
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build()
+ evaluator = BinaryClassificationEvaluator()
+
+ # fitting with parallelism=1 and parallelism=2 should produce the same metrics
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ cv.setParallelism(1)
+ cvSerialModel = cv.fit(dataset)
+ cv.setParallelism(2)
+ cvParallelModel = cv.fit(dataset)
+ self.assertEqual(cvSerialModel.avgMetrics, cvParallelModel.avgMetrics)
+
+ def test_expose_sub_models(self):
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+
+ numFolds = 3
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
+ numFolds=numFolds, collectSubModels=True)
+
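+ # With collectSubModels=True, subModels is a nested list indexed by
+ # [fold][paramMap index], which checkSubModels verifies.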
+ def checkSubModels(subModels):
+ self.assertEqual(len(subModels), numFolds)
+ for i in range(numFolds):
+ self.assertEqual(len(subModels[i]), len(grid))
+
+ cvModel = cv.fit(dataset)
+ checkSubModels(cvModel.subModels)
+
+ # Test that the "persistSubModels" write option defaults to "true"
+ testSubPath = temp_path + "/testCrossValidatorSubModels"
+ savingPathWithSubModels = testSubPath + "cvModel3"
+ cvModel.save(savingPathWithSubModels)
+ cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
+ checkSubModels(cvModel3.subModels)
+ cvModel4 = cvModel3.copy()
+ checkSubModels(cvModel4.subModels)
+
+ savingPathWithoutSubModels = testSubPath + "cvModel2"
+ cvModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
+ cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
+ self.assertEqual(cvModel2.subModels, None)
+
+ for i in range(numFolds):
+ for j in range(len(grid)):
+ self.assertEqual(cvModel.subModels[i][j].uid, cvModel3.subModels[i][j].uid)
+
+ def test_save_load_nested_estimator(self):
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+
+ ova = OneVsRest(classifier=LogisticRegression())
+ lr1 = LogisticRegression().setMaxIter(100)
+ lr2 = LogisticRegression().setMaxIter(150)
+ grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
+ evaluator = MulticlassClassificationEvaluator()
+
+ # test save/load of CrossValidator
+ cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ cvPath = temp_path + "/cv"
+ cv.save(cvPath)
+ loadedCV = CrossValidator.load(cvPath)
+ self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+ self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+
+ originalParamMap = cv.getEstimatorParamMaps()
+ loadedParamMap = loadedCV.getEstimatorParamMaps()
+ for i, param in enumerate(loadedParamMap):
+ for p in param:
+ if p.name == "classifier":
+ self.assertEqual(param[p].uid, originalParamMap[i][p].uid)
+ else:
+ self.assertEqual(param[p], originalParamMap[i][p])
+
+ # test save/load of CrossValidatorModel
+ cvModelPath = temp_path + "/cvModel"
+ cvModel.save(cvModelPath)
+ loadedModel = CrossValidatorModel.load(cvModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
+
+class TrainValidationSplitTests(SparkSessionTestCase):
+
+ def test_fit_minimize_metric(self):
+ dataset = self.spark.createDataFrame([
+ (10, 10.0),
+ (50, 50.0),
+ (100, 100.0),
+ (500, 500.0)] * 10,
+ ["feature", "label"])
+
+ iee = InducedErrorEstimator()
+ evaluator = RegressionEvaluator(metricName="rmse")
+
+ grid = ParamGridBuilder() \
+ .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \
+ .build()
+ tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ bestModel = tvsModel.bestModel
+ bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+ validationMetrics = tvsModel.validationMetrics
+
+ self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
+ "Best model should have zero induced error")
+ self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
+ self.assertEqual(len(grid), len(validationMetrics),
+ "validationMetrics has the same size of grid parameter")
+ self.assertEqual(0.0, min(validationMetrics))
+
+ def test_fit_maximize_metric(self):
+ dataset = self.spark.createDataFrame([
+ (10, 10.0),
+ (50, 50.0),
+ (100, 100.0),
+ (500, 500.0)] * 10,
+ ["feature", "label"])
+
+ iee = InducedErrorEstimator()
+ evaluator = RegressionEvaluator(metricName="r2")
+
+ grid = ParamGridBuilder() \
+ .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \
+ .build()
+ tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ bestModel = tvsModel.bestModel
+ bestModelMetric = evaluator.evaluate(bestModel.transform(dataset))
+ validationMetrics = tvsModel.validationMetrics
+
+ self.assertEqual(0.0, bestModel.getOrDefault('inducedError'),
+ "Best model should have zero induced error")
+ self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+ self.assertEqual(len(grid), len(validationMetrics),
+ "validationMetrics has the same size of grid parameter")
+ self.assertEqual(1.0, max(validationMetrics))
+
+ def test_save_load_trained_model(self):
+ # This tests saving and loading the trained model only.
+ # Save/load for TrainValidationSplit will be added later: SPARK-13786
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ lrModel = tvsModel.bestModel
+
+ tvsModelPath = temp_path + "/tvsModel"
+ lrModel.save(tvsModelPath)
+ loadedLrModel = LogisticRegressionModel.load(tvsModelPath)
+ self.assertEqual(loadedLrModel.uid, lrModel.uid)
+ self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
+
+ def test_save_load_simple_estimator(self):
+ # Tests save/load of both TrainValidationSplit and TrainValidationSplitModel.
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+
+ tvsPath = temp_path + "/tvs"
+ tvs.save(tvsPath)
+ loadedTvs = TrainValidationSplit.load(tvsPath)
+ self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+ self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+ self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+
+ tvsModelPath = temp_path + "/tvsModel"
+ tvsModel.save(tvsModelPath)
+ loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
+ def test_parallel_evaluation(self):
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [5, 6]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ tvs.setParallelism(1)
+ tvsSerialModel = tvs.fit(dataset)
+ tvs.setParallelism(2)
+ tvsParallelModel = tvs.fit(dataset)
+ self.assertEqual(tvsSerialModel.validationMetrics, tvsParallelModel.validationMetrics)
+
+ def test_expose_sub_models(self):
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
+ collectSubModels=True)
+ tvsModel = tvs.fit(dataset)
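+ # Unlike CrossValidator, TrainValidationSplit exposes a flat list with one
+ # sub-model per param map.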
+ self.assertEqual(len(tvsModel.subModels), len(grid))
+
+ # Test that the "persistSubModels" write option defaults to "true"
+ testSubPath = temp_path + "/testTrainValidationSplitSubModels"
+ savingPathWithSubModels = testSubPath + "cvModel3"
+ tvsModel.save(savingPathWithSubModels)
+ tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels)
+ self.assertEqual(len(tvsModel3.subModels), len(grid))
+ tvsModel4 = tvsModel3.copy()
+ self.assertEqual(len(tvsModel4.subModels), len(grid))
+
+ savingPathWithoutSubModels = testSubPath + "cvModel2"
+ tvsModel.write().option("persistSubModels", "false").save(savingPathWithoutSubModels)
+ tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
+ self.assertEqual(tvsModel2.subModels, None)
+
+ for i in range(len(grid)):
+ self.assertEqual(tvsModel.subModels[i].uid, tvsModel3.subModels[i].uid)
+
+ def test_save_load_nested_estimator(self):
+ # Tests save/load of TrainValidationSplit with a nested (OneVsRest) estimator.
+ temp_path = tempfile.mkdtemp()
+ dataset = self.spark.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ ova = OneVsRest(classifier=LogisticRegression())
+ lr1 = LogisticRegression().setMaxIter(100)
+ lr2 = LogisticRegression().setMaxIter(150)
+ grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
+ evaluator = MulticlassClassificationEvaluator()
+
+ tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ tvsPath = temp_path + "/tvs"
+ tvs.save(tvsPath)
+ loadedTvs = TrainValidationSplit.load(tvsPath)
+ self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+ self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+
+ originalParamMap = tvs.getEstimatorParamMaps()
+ loadedParamMap = loadedTvs.getEstimatorParamMaps()
+ for i, param in enumerate(loadedParamMap):
+ for p in param:
+ if p.name == "classifier":
+ self.assertEqual(param[p].uid, originalParamMap[i][p].uid)
+ else:
+ self.assertEqual(param[p], originalParamMap[i][p])
+
+ tvsModelPath = temp_path + "/tvsModel"
+ tvsModel.save(tvsModelPath)
+ loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
+ def test_copy(self):
+ dataset = self.spark.createDataFrame([
+ (10, 10.0),
+ (50, 50.0),
+ (100, 100.0),
+ (500, 500.0)] * 10,
+ ["feature", "label"])
+
+ iee = InducedErrorEstimator()
+ evaluator = RegressionEvaluator(metricName="r2")
+
+ grid = ParamGridBuilder() \
+ .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \
+ .build()
+ tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ tvsCopied = tvs.copy()
+ tvsModelCopied = tvsModel.copy()
+
+ self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid,
+ "Copied TrainValidationSplit has the same uid of Estimator")
+
+ self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid)
+ self.assertEqual(len(tvsModel.validationMetrics),
+ len(tvsModelCopied.validationMetrics),
+ "Copied validationMetrics has the same size of the original")
+ for index in range(len(tvsModel.validationMetrics)):
+ self.assertEqual(tvsModel.validationMetrics[index],
+ tvsModelCopied.validationMetrics[index])
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_tuning import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py
new file mode 100644
index 0000000000000..ae672a00c1dc1
--- /dev/null
+++ b/python/pyspark/ml/tests/test_wrapper.py
@@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import py4j
+
+from pyspark.ml.linalg import DenseVector, Vectors
+from pyspark.ml.regression import LinearRegression
+from pyspark.ml.wrapper import _java2py, _py2java, JavaParams, JavaWrapper
+from pyspark.testing.mllibutils import MLlibTestCase
+from pyspark.testing.mlutils import SparkSessionTestCase
+
+
+class JavaWrapperMemoryTests(SparkSessionTestCase):
+
+ def test_java_object_gets_detached(self):
+ df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight",
+ fitIntercept=False)
+
+ model = lr.fit(df)
+ summary = model.summary
+
+ self.assertIsInstance(model, JavaWrapper)
+ self.assertIsInstance(summary, JavaWrapper)
+ self.assertIsInstance(model, JavaParams)
+ self.assertNotIsInstance(summary, JavaParams)
+
+ error_no_object = 'Target Object ID does not exist for this gateway'
+
+ self.assertIn("LinearRegression_", model._java_obj.toString())
+ self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
+
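+ # Deleting the model wrapper should detach its Java object from the gateway,
+ # while the summary still holds its own live reference.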
+ model.__del__()
+
+ with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+ model._java_obj.toString()
+ self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
+
+ try:
+ summary.__del__()
+ except Exception:
+ pass
+
+ with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+ model._java_obj.toString()
+ with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+ summary._java_obj.toString()
+
+
+class WrapperTests(MLlibTestCase):
+
+ def test_new_java_array(self):
+ # test array of strings
+ str_list = ["a", "b", "c"]
+ java_class = self.sc._gateway.jvm.java.lang.String
+ java_array = JavaWrapper._new_java_array(str_list, java_class)
+ self.assertEqual(_java2py(self.sc, java_array), str_list)
+ # test array of integers
+ int_list = [1, 2, 3]
+ java_class = self.sc._gateway.jvm.java.lang.Integer
+ java_array = JavaWrapper._new_java_array(int_list, java_class)
+ self.assertEqual(_java2py(self.sc, java_array), int_list)
+ # test array of floats
+ float_list = [0.1, 0.2, 0.3]
+ java_class = self.sc._gateway.jvm.java.lang.Double
+ java_array = JavaWrapper._new_java_array(float_list, java_class)
+ self.assertEqual(_java2py(self.sc, java_array), float_list)
+ # test array of bools
+ bool_list = [False, True, True]
+ java_class = self.sc._gateway.jvm.java.lang.Boolean
+ java_array = JavaWrapper._new_java_array(bool_list, java_class)
+ self.assertEqual(_java2py(self.sc, java_array), bool_list)
+ # test array of Java DenseVectors
+ v1 = DenseVector([0.0, 1.0])
+ v2 = DenseVector([1.0, 0.0])
+ vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)]
+ java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector
+ java_array = JavaWrapper._new_java_array(vec_java_list, java_class)
+ self.assertEqual(_java2py(self.sc, java_array), [v1, v2])
+ # test empty array
+ java_class = self.sc._gateway.jvm.java.lang.Integer
+ java_array = JavaWrapper._new_java_array([], java_class)
+ self.assertEqual(_java2py(self.sc, java_array), [])
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_wrapper import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index b1a8af6bcc094..4f4355ddb60ee 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -184,7 +184,7 @@ class KMeansModel(Saveable, Loader):
>>> model.k
2
>>> model.computeCost(sc.parallelize(data))
- 2.0000000000000004
+ 2.0
>>> model = KMeans.train(sc.parallelize(data), 2)
>>> sparse_data = [
... SparseVector(3, {1: 1.0}),
diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
index de18dad1f675d..6accb9b4926e8 100644
--- a/python/pyspark/mllib/fpm.py
+++ b/python/pyspark/mllib/fpm.py
@@ -132,7 +132,7 @@ class PrefixSpan(object):
A parallel PrefixSpan algorithm to mine frequent sequential patterns.
The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan:
Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth
- ([[http://doi.org/10.1109/ICDE.2001.914830]]).
+ ([[https://doi.org/10.1109/ICDE.2001.914830]]).
.. versionadded:: 1.6.0
"""
diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py
index 7e8b15056cabe..b7f09782be9dd 100644
--- a/python/pyspark/mllib/linalg/distributed.py
+++ b/python/pyspark/mllib/linalg/distributed.py
@@ -270,7 +270,7 @@ def tallSkinnyQR(self, computeQ=False):
Reference:
Paul G. Constantine, David F. Gleich. "Tall and skinny QR
factorizations in MapReduce architectures"
- ([[http://dx.doi.org/10.1145/1996092.1996103]])
+ ([[https://doi.org/10.1145/1996092.1996103]])
:param: computeQ: whether to computeQ
:return: QRDecomposition(Q: RowMatrix, R: Matrix), where
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
deleted file mode 100644
index 653f5cb9ff4a2..0000000000000
--- a/python/pyspark/mllib/tests.py
+++ /dev/null
@@ -1,1788 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Fuller unit tests for Python MLlib.
-"""
-
-import os
-import sys
-import tempfile
-import array as pyarray
-from math import sqrt
-from time import time, sleep
-from shutil import rmtree
-
-import unishark
-from numpy import (
- array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones)
-from numpy import sum as array_sum
-
-from py4j.protocol import Py4JJavaError
-
-if sys.version > '3':
- basestring = str
-
-if sys.version_info[:2] <= (2, 6):
- try:
- import unittest2 as unittest
- except ImportError:
- sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
- sys.exit(1)
-else:
- import unittest
-
-from pyspark import SparkContext
-import pyspark.ml.linalg as newlinalg
-from pyspark.mllib.common import _to_java_object_rdd
-from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
-from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
- DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
-from pyspark.mllib.linalg.distributed import RowMatrix
-from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
-from pyspark.mllib.fpm import FPGrowth
-from pyspark.mllib.recommendation import Rating
-from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
-from pyspark.mllib.random import RandomRDDs
-from pyspark.mllib.stat import Statistics
-from pyspark.mllib.feature import HashingTF
-from pyspark.mllib.feature import Word2Vec
-from pyspark.mllib.feature import IDF
-from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
-from pyspark.mllib.util import LinearDataGenerator
-from pyspark.mllib.util import MLUtils
-from pyspark.serializers import PickleSerializer
-from pyspark.streaming import StreamingContext
-from pyspark.sql import SparkSession
-from pyspark.sql.utils import IllegalArgumentException
-from pyspark.streaming import StreamingContext
-
-_have_scipy = False
-try:
- import scipy.sparse
- _have_scipy = True
-except:
- # No SciPy, but that's okay, we'll skip those tests
- pass
-
-ser = PickleSerializer()
-
-
-class MLlibTestCase(unittest.TestCase):
- def setUp(self):
- self.sc = SparkContext('local[4]', "MLlib tests")
- self.spark = SparkSession(self.sc)
-
- def tearDown(self):
- self.spark.stop()
-
-
-class MLLibStreamingTestCase(unittest.TestCase):
- def setUp(self):
- self.sc = SparkContext('local[4]', "MLlib tests")
- self.ssc = StreamingContext(self.sc, 1.0)
-
- def tearDown(self):
- self.ssc.stop(False)
- self.sc.stop()
-
- @staticmethod
- def _eventually(condition, timeout=30.0, catch_assertions=False):
- """
- Wait a given amount of time for a condition to pass, else fail with an error.
- This is a helper utility for streaming ML tests.
- :param condition: Function that checks for termination conditions.
- condition() can return:
- - True: Conditions met. Return without error.
- - other value: Conditions not met yet. Continue. Upon timeout,
- include last such value in error message.
- Note that this method may be called at any time during
- streaming execution (e.g., even before any results
- have been created).
- :param timeout: Number of seconds to wait. Default 30 seconds.
- :param catch_assertions: If False (default), do not catch AssertionErrors.
- If True, catch AssertionErrors; continue, but save
- error to throw upon timeout.
- """
- start_time = time()
- lastValue = None
- while time() - start_time < timeout:
- if catch_assertions:
- try:
- lastValue = condition()
- except AssertionError as e:
- lastValue = e
- else:
- lastValue = condition()
- if lastValue is True:
- return
- sleep(0.01)
- if isinstance(lastValue, AssertionError):
- raise lastValue
- else:
- raise AssertionError(
- "Test failed due to timeout after %g sec, with last condition returning: %s"
- % (timeout, lastValue))
-
-
-def _squared_distance(a, b):
- if isinstance(a, Vector):
- return a.squared_distance(b)
- else:
- return b.squared_distance(a)
-
-
-class VectorTests(MLlibTestCase):
-
- def _test_serialize(self, v):
- self.assertEqual(v, ser.loads(ser.dumps(v)))
- jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v)))
- nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec)))
- self.assertEqual(v, nv)
- vs = [v] * 100
- jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs)))
- nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs)))
- self.assertEqual(vs, nvs)
-
- def test_serialize(self):
- self._test_serialize(DenseVector(range(10)))
- self._test_serialize(DenseVector(array([1., 2., 3., 4.])))
- self._test_serialize(DenseVector(pyarray.array('d', range(10))))
- self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
- self._test_serialize(SparseVector(3, {}))
- self._test_serialize(DenseMatrix(2, 3, range(6)))
- sm1 = SparseMatrix(
- 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
- self._test_serialize(sm1)
-
- def test_dot(self):
- sv = SparseVector(4, {1: 1, 3: 2})
- dv = DenseVector(array([1., 2., 3., 4.]))
- lst = DenseVector([1, 2, 3, 4])
- mat = array([[1., 2., 3., 4.],
- [1., 2., 3., 4.],
- [1., 2., 3., 4.],
- [1., 2., 3., 4.]])
- arr = pyarray.array('d', [0, 1, 2, 3])
- self.assertEqual(10.0, sv.dot(dv))
- self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat)))
- self.assertEqual(30.0, dv.dot(dv))
- self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat)))
- self.assertEqual(30.0, lst.dot(dv))
- self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat)))
- self.assertEqual(7.0, sv.dot(arr))
-
- def test_squared_distance(self):
- sv = SparseVector(4, {1: 1, 3: 2})
- dv = DenseVector(array([1., 2., 3., 4.]))
- lst = DenseVector([4, 3, 2, 1])
- lst1 = [4, 3, 2, 1]
- arr = pyarray.array('d', [0, 2, 1, 3])
- narr = array([0, 2, 1, 3])
- self.assertEqual(15.0, _squared_distance(sv, dv))
- self.assertEqual(25.0, _squared_distance(sv, lst))
- self.assertEqual(20.0, _squared_distance(dv, lst))
- self.assertEqual(15.0, _squared_distance(dv, sv))
- self.assertEqual(25.0, _squared_distance(lst, sv))
- self.assertEqual(20.0, _squared_distance(lst, dv))
- self.assertEqual(0.0, _squared_distance(sv, sv))
- self.assertEqual(0.0, _squared_distance(dv, dv))
- self.assertEqual(0.0, _squared_distance(lst, lst))
- self.assertEqual(25.0, _squared_distance(sv, lst1))
- self.assertEqual(3.0, _squared_distance(sv, arr))
- self.assertEqual(3.0, _squared_distance(sv, narr))
-
- def test_hash(self):
- v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
- v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
- v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
- v4 = SparseVector(4, [(1, 1.0), (3, 2.5)])
- self.assertEqual(hash(v1), hash(v2))
- self.assertEqual(hash(v1), hash(v3))
- self.assertEqual(hash(v2), hash(v3))
- self.assertFalse(hash(v1) == hash(v4))
- self.assertFalse(hash(v2) == hash(v4))
-
- def test_eq(self):
- v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
- v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
- v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
- v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
- v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
- v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
- self.assertEqual(v1, v2)
- self.assertEqual(v1, v3)
- self.assertFalse(v2 == v4)
- self.assertFalse(v1 == v5)
- self.assertFalse(v1 == v6)
-
- def test_equals(self):
- indices = [1, 2, 4]
- values = [1., 3., 2.]
- self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.]))
- self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.]))
- self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.]))
- self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.]))
-
- def test_conversion(self):
- # numpy arrays should be automatically upcast to float64
- # tests for fix of [SPARK-5089]
- v = array([1, 2, 3, 4], dtype='float64')
- dv = DenseVector(v)
- self.assertTrue(dv.array.dtype == 'float64')
- v = array([1, 2, 3, 4], dtype='float32')
- dv = DenseVector(v)
- self.assertTrue(dv.array.dtype == 'float64')
-
- def test_sparse_vector_indexing(self):
- sv = SparseVector(5, {1: 1, 3: 2})
- self.assertEqual(sv[0], 0.)
- self.assertEqual(sv[3], 2.)
- self.assertEqual(sv[1], 1.)
- self.assertEqual(sv[2], 0.)
- self.assertEqual(sv[4], 0.)
- self.assertEqual(sv[-1], 0.)
- self.assertEqual(sv[-2], 2.)
- self.assertEqual(sv[-3], 0.)
- self.assertEqual(sv[-5], 0.)
- for ind in [5, -6]:
- self.assertRaises(IndexError, sv.__getitem__, ind)
- for ind in [7.8, '1']:
- self.assertRaises(TypeError, sv.__getitem__, ind)
-
- zeros = SparseVector(4, {})
- self.assertEqual(zeros[0], 0.0)
- self.assertEqual(zeros[3], 0.0)
- for ind in [4, -5]:
- self.assertRaises(IndexError, zeros.__getitem__, ind)
-
- empty = SparseVector(0, {})
- for ind in [-1, 0, 1]:
- self.assertRaises(IndexError, empty.__getitem__, ind)
-
- def test_sparse_vector_iteration(self):
- self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0])
- self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0])
-
- def test_matrix_indexing(self):
- mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
- expected = [[0, 6], [1, 8], [4, 10]]
- for i in range(3):
- for j in range(2):
- self.assertEqual(mat[i, j], expected[i][j])
-
- for i, j in [(-1, 0), (4, 1), (3, 4)]:
- self.assertRaises(IndexError, mat.__getitem__, (i, j))
-
- def test_repr_dense_matrix(self):
- mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
- self.assertTrue(
- repr(mat),
- 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
-
- mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True)
- self.assertTrue(
- repr(mat),
- 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
-
- mat = DenseMatrix(6, 3, zeros(18))
- self.assertTrue(
- repr(mat),
- 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \
- 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)')
-
- def test_repr_sparse_matrix(self):
- sm1t = SparseMatrix(
- 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
- isTransposed=True)
- self.assertTrue(
- repr(sm1t),
- 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)')
-
- indices = tile(arange(6), 3)
- values = ones(18)
- sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values)
- self.assertTrue(
- repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \
- [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \
- [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \
- 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)")
-
- self.assertTrue(
- str(sm),
- "6 X 3 CSCMatrix\n\
- (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\
- (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\
- (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..")
-
- sm = SparseMatrix(1, 18, zeros(19), [], [])
- self.assertTrue(
- repr(sm),
- 'SparseMatrix(1, 18, \
- [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)')
-
- def test_sparse_matrix(self):
- # Test sparse matrix creation.
- sm1 = SparseMatrix(
- 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
- self.assertEqual(sm1.numRows, 3)
- self.assertEqual(sm1.numCols, 4)
- self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4])
- self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2])
- self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0])
- self.assertTrue(
- repr(sm1),
- 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)')
-
- # Test indexing
- expected = [
- [0, 0, 0, 0],
- [1, 0, 4, 0],
- [2, 0, 5, 0]]
-
- for i in range(3):
- for j in range(4):
- self.assertEqual(expected[i][j], sm1[i, j])
- self.assertTrue(array_equal(sm1.toArray(), expected))
-
- for i, j in [(-1, 1), (4, 3), (3, 5)]:
- self.assertRaises(IndexError, sm1.__getitem__, (i, j))
-
- # Test conversion to dense and sparse.
- smnew = sm1.toDense().toSparse()
- self.assertEqual(sm1.numRows, smnew.numRows)
- self.assertEqual(sm1.numCols, smnew.numCols)
- self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs))
- self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices))
- self.assertTrue(array_equal(sm1.values, smnew.values))
-
- sm1t = SparseMatrix(
- 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
- isTransposed=True)
- self.assertEqual(sm1t.numRows, 3)
- self.assertEqual(sm1t.numCols, 4)
- self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5])
- self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2])
- self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0])
-
- expected = [
- [3, 2, 0, 0],
- [0, 0, 4, 0],
- [9, 0, 8, 0]]
-
- for i in range(3):
- for j in range(4):
- self.assertEqual(expected[i][j], sm1t[i, j])
- self.assertTrue(array_equal(sm1t.toArray(), expected))
-
- def test_dense_matrix_is_transposed(self):
- mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True)
- mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9])
- self.assertEqual(mat1, mat)
-
- expected = [[0, 4], [1, 6], [3, 9]]
- for i in range(3):
- for j in range(2):
- self.assertEqual(mat1[i, j], expected[i][j])
- self.assertTrue(array_equal(mat1.toArray(), expected))
-
- sm = mat1.toSparse()
- self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2]))
- self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
- self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
-
- def test_parse_vector(self):
- a = DenseVector([])
- self.assertEqual(str(a), '[]')
- self.assertEqual(Vectors.parse(str(a)), a)
- a = DenseVector([3, 4, 6, 7])
- self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]')
- self.assertEqual(Vectors.parse(str(a)), a)
- a = SparseVector(4, [], [])
- self.assertEqual(str(a), '(4,[],[])')
- self.assertEqual(SparseVector.parse(str(a)), a)
- a = SparseVector(4, [0, 2], [3, 4])
- self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])')
- self.assertEqual(Vectors.parse(str(a)), a)
- a = SparseVector(10, [0, 1], [4, 5])
- self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
-
- def test_norms(self):
- a = DenseVector([0, 2, 3, -1])
- self.assertAlmostEqual(a.norm(2), 3.742, 3)
- self.assertTrue(a.norm(1), 6)
- self.assertTrue(a.norm(inf), 3)
- a = SparseVector(4, [0, 2], [3, -4])
- self.assertAlmostEqual(a.norm(2), 5)
- self.assertTrue(a.norm(1), 7)
- self.assertTrue(a.norm(inf), 4)
-
- tmp = SparseVector(4, [0, 2], [3, 0])
- self.assertEqual(tmp.numNonzeros(), 1)
-
- def test_ml_mllib_vector_conversion(self):
- # to ml
- # dense
- mllibDV = Vectors.dense([1, 2, 3])
- mlDV1 = newlinalg.Vectors.dense([1, 2, 3])
- mlDV2 = mllibDV.asML()
- self.assertEqual(mlDV2, mlDV1)
- # sparse
- mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5})
- mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5})
- mlSV2 = mllibSV.asML()
- self.assertEqual(mlSV2, mlSV1)
- # from ml
- # dense
- mllibDV1 = Vectors.dense([1, 2, 3])
- mlDV = newlinalg.Vectors.dense([1, 2, 3])
- mllibDV2 = Vectors.fromML(mlDV)
- self.assertEqual(mllibDV1, mllibDV2)
- # sparse
- mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5})
- mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5})
- mllibSV2 = Vectors.fromML(mlSV)
- self.assertEqual(mllibSV1, mllibSV2)
-
- def test_ml_mllib_matrix_conversion(self):
- # to ml
- # dense
- mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3])
- mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3])
- mlDM2 = mllibDM.asML()
- self.assertEqual(mlDM2, mlDM1)
- # transposed
- mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True)
- mlDMt1 = newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True)
- mlDMt2 = mllibDMt.asML()
- self.assertEqual(mlDMt2, mlDMt1)
- # sparse
- mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
- mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
- mlSM2 = mllibSM.asML()
- self.assertEqual(mlSM2, mlSM1)
- # transposed
- mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
- mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
- mlSMt2 = mllibSMt.asML()
- self.assertEqual(mlSMt2, mlSMt1)
- # from ml
- # dense
- mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4])
- mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4])
- mllibDM2 = Matrices.fromML(mlDM)
- self.assertEqual(mllibDM1, mllibDM2)
- # transposed
- mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], True)
- mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True)
- mllibDMt2 = Matrices.fromML(mlDMt)
- self.assertEqual(mllibDMt1, mllibDMt2)
- # sparse
- mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
- mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
- mllibSM2 = Matrices.fromML(mlSM)
- self.assertEqual(mllibSM1, mllibSM2)
- # transposed
- mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
- mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
- mllibSMt2 = Matrices.fromML(mlSMt)
- self.assertEqual(mllibSMt1, mllibSMt2)
-
-
-class ListTests(MLlibTestCase):
-
- """
- Test MLlib algorithms on plain lists, to make sure they're passed through
- as NumPy arrays.
- """
-
- def test_bisecting_kmeans(self):
- from pyspark.mllib.clustering import BisectingKMeans
- data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2)
- bskm = BisectingKMeans()
- model = bskm.train(self.sc.parallelize(data, 2), k=4)
- p = array([0.0, 0.0])
- rdd_p = self.sc.parallelize([p])
- self.assertEqual(model.predict(p), model.predict(rdd_p).first())
- self.assertEqual(model.computeCost(p), model.computeCost(rdd_p))
- self.assertEqual(model.k, len(model.clusterCenters))
-
- def test_kmeans(self):
- from pyspark.mllib.clustering import KMeans
- data = [
- [0, 1.1],
- [0, 1.2],
- [1.1, 0],
- [1.2, 0],
- ]
- clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||",
- initializationSteps=7, epsilon=1e-4)
- self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1]))
- self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3]))
-
- def test_kmeans_deterministic(self):
- from pyspark.mllib.clustering import KMeans
- X = range(0, 100, 10)
- Y = range(0, 100, 10)
- data = [[x, y] for x, y in zip(X, Y)]
- clusters1 = KMeans.train(self.sc.parallelize(data),
- 3, initializationMode="k-means||",
- seed=42, initializationSteps=7, epsilon=1e-4)
- clusters2 = KMeans.train(self.sc.parallelize(data),
- 3, initializationMode="k-means||",
- seed=42, initializationSteps=7, epsilon=1e-4)
- centers1 = clusters1.centers
- centers2 = clusters2.centers
- for c1, c2 in zip(centers1, centers2):
- # TODO: Allow small numeric difference.
- self.assertTrue(array_equal(c1, c2))
-
- def test_gmm(self):
- from pyspark.mllib.clustering import GaussianMixture
- data = self.sc.parallelize([
- [1, 2],
- [8, 9],
- [-4, -3],
- [-6, -7],
- ])
- clusters = GaussianMixture.train(data, 2, convergenceTol=0.001,
- maxIterations=10, seed=1)
- labels = clusters.predict(data).collect()
- self.assertEqual(labels[0], labels[1])
- self.assertEqual(labels[2], labels[3])
-
- def test_gmm_deterministic(self):
- from pyspark.mllib.clustering import GaussianMixture
- x = range(0, 100, 10)
- y = range(0, 100, 10)
- data = self.sc.parallelize([[a, b] for a, b in zip(x, y)])
- clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001,
- maxIterations=10, seed=63)
- clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001,
- maxIterations=10, seed=63)
- for c1, c2 in zip(clusters1.weights, clusters2.weights):
- self.assertEqual(round(c1, 7), round(c2, 7))
-
- def test_gmm_with_initial_model(self):
- from pyspark.mllib.clustering import GaussianMixture
- data = self.sc.parallelize([
- (-10, -5), (-9, -4), (10, 5), (9, 4)
- ])
-
- gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
- maxIterations=10, seed=63)
- gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
- maxIterations=10, seed=63, initialModel=gmm1)
- self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0)
-
- def test_classification(self):
- from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
- from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\
- RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel
- data = [
- LabeledPoint(0.0, [1, 0, 0]),
- LabeledPoint(1.0, [0, 1, 1]),
- LabeledPoint(0.0, [2, 0, 0]),
- LabeledPoint(1.0, [0, 2, 1])
- ]
- rdd = self.sc.parallelize(data)
- features = [p.features.tolist() for p in data]
-
- temp_dir = tempfile.mkdtemp()
-
- lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10)
- self.assertTrue(lr_model.predict(features[0]) <= 0)
- self.assertTrue(lr_model.predict(features[1]) > 0)
- self.assertTrue(lr_model.predict(features[2]) <= 0)
- self.assertTrue(lr_model.predict(features[3]) > 0)
-
- svm_model = SVMWithSGD.train(rdd, iterations=10)
- self.assertTrue(svm_model.predict(features[0]) <= 0)
- self.assertTrue(svm_model.predict(features[1]) > 0)
- self.assertTrue(svm_model.predict(features[2]) <= 0)
- self.assertTrue(svm_model.predict(features[3]) > 0)
-
- nb_model = NaiveBayes.train(rdd)
- self.assertTrue(nb_model.predict(features[0]) <= 0)
- self.assertTrue(nb_model.predict(features[1]) > 0)
- self.assertTrue(nb_model.predict(features[2]) <= 0)
- self.assertTrue(nb_model.predict(features[3]) > 0)
-
- categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
- dt_model = DecisionTree.trainClassifier(
- rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4)
- self.assertTrue(dt_model.predict(features[0]) <= 0)
- self.assertTrue(dt_model.predict(features[1]) > 0)
- self.assertTrue(dt_model.predict(features[2]) <= 0)
- self.assertTrue(dt_model.predict(features[3]) > 0)
-
- dt_model_dir = os.path.join(temp_dir, "dt")
- dt_model.save(self.sc, dt_model_dir)
- same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir)
- self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString())
-
- rf_model = RandomForest.trainClassifier(
- rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10,
- maxBins=4, seed=1)
- self.assertTrue(rf_model.predict(features[0]) <= 0)
- self.assertTrue(rf_model.predict(features[1]) > 0)
- self.assertTrue(rf_model.predict(features[2]) <= 0)
- self.assertTrue(rf_model.predict(features[3]) > 0)
-
- rf_model_dir = os.path.join(temp_dir, "rf")
- rf_model.save(self.sc, rf_model_dir)
- same_rf_model = RandomForestModel.load(self.sc, rf_model_dir)
- self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString())
-
- gbt_model = GradientBoostedTrees.trainClassifier(
- rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4)
- self.assertTrue(gbt_model.predict(features[0]) <= 0)
- self.assertTrue(gbt_model.predict(features[1]) > 0)
- self.assertTrue(gbt_model.predict(features[2]) <= 0)
- self.assertTrue(gbt_model.predict(features[3]) > 0)
-
- gbt_model_dir = os.path.join(temp_dir, "gbt")
- gbt_model.save(self.sc, gbt_model_dir)
- same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir)
- self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString())
-
- try:
- rmtree(temp_dir)
- except OSError:
- pass
-
- def test_regression(self):
- from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
- RidgeRegressionWithSGD
- from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
- data = [
- LabeledPoint(-1.0, [0, -1]),
- LabeledPoint(1.0, [0, 1]),
- LabeledPoint(-1.0, [0, -2]),
- LabeledPoint(1.0, [0, 2])
- ]
- rdd = self.sc.parallelize(data)
- features = [p.features.tolist() for p in data]
-
- lr_model = LinearRegressionWithSGD.train(rdd, iterations=10)
- self.assertTrue(lr_model.predict(features[0]) <= 0)
- self.assertTrue(lr_model.predict(features[1]) > 0)
- self.assertTrue(lr_model.predict(features[2]) <= 0)
- self.assertTrue(lr_model.predict(features[3]) > 0)
-
- lasso_model = LassoWithSGD.train(rdd, iterations=10)
- self.assertTrue(lasso_model.predict(features[0]) <= 0)
- self.assertTrue(lasso_model.predict(features[1]) > 0)
- self.assertTrue(lasso_model.predict(features[2]) <= 0)
- self.assertTrue(lasso_model.predict(features[3]) > 0)
-
- rr_model = RidgeRegressionWithSGD.train(rdd, iterations=10)
- self.assertTrue(rr_model.predict(features[0]) <= 0)
- self.assertTrue(rr_model.predict(features[1]) > 0)
- self.assertTrue(rr_model.predict(features[2]) <= 0)
- self.assertTrue(rr_model.predict(features[3]) > 0)
-
- categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
- dt_model = DecisionTree.trainRegressor(
- rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4)
- self.assertTrue(dt_model.predict(features[0]) <= 0)
- self.assertTrue(dt_model.predict(features[1]) > 0)
- self.assertTrue(dt_model.predict(features[2]) <= 0)
- self.assertTrue(dt_model.predict(features[3]) > 0)
-
- rf_model = RandomForest.trainRegressor(
- rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1)
- self.assertTrue(rf_model.predict(features[0]) <= 0)
- self.assertTrue(rf_model.predict(features[1]) > 0)
- self.assertTrue(rf_model.predict(features[2]) <= 0)
- self.assertTrue(rf_model.predict(features[3]) > 0)
-
- gbt_model = GradientBoostedTrees.trainRegressor(
- rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4)
- self.assertTrue(gbt_model.predict(features[0]) <= 0)
- self.assertTrue(gbt_model.predict(features[1]) > 0)
- self.assertTrue(gbt_model.predict(features[2]) <= 0)
- self.assertTrue(gbt_model.predict(features[3]) > 0)
-
- try:
- LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10)
- LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10)
- RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10)
- except ValueError:
- self.fail()
-
- # Verify that maxBins is being passed through
- GradientBoostedTrees.trainRegressor(
- rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32)
- with self.assertRaises(Exception) as cm:
- GradientBoostedTrees.trainRegressor(
- rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1)
-
-
-class StatTests(MLlibTestCase):
- # SPARK-4023
- def test_col_with_different_rdds(self):
- # numpy
- data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
- summary = Statistics.colStats(data)
- self.assertEqual(1000, summary.count())
- # array
- data = self.sc.parallelize([range(10)] * 10)
- summary = Statistics.colStats(data)
- self.assertEqual(10, summary.count())
- # array
- data = self.sc.parallelize([pyarray.array("d", range(10))] * 10)
- summary = Statistics.colStats(data)
- self.assertEqual(10, summary.count())
-
- def test_col_norms(self):
- data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
- summary = Statistics.colStats(data)
- self.assertEqual(10, len(summary.normL1()))
- self.assertEqual(10, len(summary.normL2()))
-
- data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x))
- summary2 = Statistics.colStats(data2)
- self.assertEqual(array([45.0]), summary2.normL1())
- import math
- expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10))))
- self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14)
-
-
-class VectorUDTTests(MLlibTestCase):
-
- dv0 = DenseVector([])
- dv1 = DenseVector([1.0, 2.0])
- sv0 = SparseVector(2, [], [])
- sv1 = SparseVector(2, [1], [2.0])
- udt = VectorUDT()
-
- def test_json_schema(self):
- self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt)
-
- def test_serialization(self):
- for v in [self.dv0, self.dv1, self.sv0, self.sv1]:
- self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
-
- def test_infer_schema(self):
- rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
- df = rdd.toDF()
- schema = df.schema
- field = [f for f in schema.fields if f.name == "features"][0]
- self.assertEqual(field.dataType, self.udt)
- vectors = df.rdd.map(lambda p: p.features).collect()
- self.assertEqual(len(vectors), 2)
- for v in vectors:
- if isinstance(v, SparseVector):
- self.assertEqual(v, self.sv1)
- elif isinstance(v, DenseVector):
- self.assertEqual(v, self.dv1)
- else:
- raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
-
-
-class MatrixUDTTests(MLlibTestCase):
-
- dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
- dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
- sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
- sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
- udt = MatrixUDT()
-
- def test_json_schema(self):
- self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
-
- def test_serialization(self):
- for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
- self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
-
- def test_infer_schema(self):
- rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
- df = rdd.toDF()
- schema = df.schema
-        self.assertEqual(schema.fields[1].dataType, self.udt)
- matrices = df.rdd.map(lambda x: x._2).collect()
- self.assertEqual(len(matrices), 2)
- for m in matrices:
- if isinstance(m, DenseMatrix):
-                self.assertEqual(m, self.dm1)
-            elif isinstance(m, SparseMatrix):
-                self.assertEqual(m, self.sm1)
- else:
- raise ValueError("Expected a matrix but got type %r" % type(m))
-
-
-@unittest.skipIf(not _have_scipy, "SciPy not installed")
-class SciPyTests(MLlibTestCase):
-
- """
- Test both vector operations and MLlib algorithms with SciPy sparse matrices,
- if SciPy is available.
- """
-
- def test_serialize(self):
- from scipy.sparse import lil_matrix
- lil = lil_matrix((4, 1))
- lil[1, 0] = 1
- lil[3, 0] = 2
- sv = SparseVector(4, {1: 1, 3: 2})
- self.assertEqual(sv, _convert_to_vector(lil))
- self.assertEqual(sv, _convert_to_vector(lil.tocsc()))
- self.assertEqual(sv, _convert_to_vector(lil.tocoo()))
- self.assertEqual(sv, _convert_to_vector(lil.tocsr()))
- self.assertEqual(sv, _convert_to_vector(lil.todok()))
-
- def serialize(l):
- return ser.loads(ser.dumps(_convert_to_vector(l)))
- self.assertEqual(sv, serialize(lil))
- self.assertEqual(sv, serialize(lil.tocsc()))
- self.assertEqual(sv, serialize(lil.tocsr()))
- self.assertEqual(sv, serialize(lil.todok()))
-
- def test_convert_to_vector(self):
- from scipy.sparse import csc_matrix
- # Create a CSC matrix with non-sorted indices
- indptr = array([0, 2])
- indices = array([3, 1])
- data = array([2.0, 1.0])
- csc = csc_matrix((data, indices, indptr))
- self.assertFalse(csc.has_sorted_indices)
- sv = SparseVector(4, {1: 1, 3: 2})
- self.assertEqual(sv, _convert_to_vector(csc))
-
- def test_dot(self):
- from scipy.sparse import lil_matrix
- lil = lil_matrix((4, 1))
- lil[1, 0] = 1
- lil[3, 0] = 2
- dv = DenseVector(array([1., 2., 3., 4.]))
- self.assertEqual(10.0, dv.dot(lil))
-
- def test_squared_distance(self):
- from scipy.sparse import lil_matrix
- lil = lil_matrix((4, 1))
- lil[1, 0] = 3
- lil[3, 0] = 2
- dv = DenseVector(array([1., 2., 3., 4.]))
- sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4})
- self.assertEqual(15.0, dv.squared_distance(lil))
- self.assertEqual(15.0, sv.squared_distance(lil))
-
- def scipy_matrix(self, size, values):
- """Create a column SciPy matrix from a dictionary of values"""
- from scipy.sparse import lil_matrix
- lil = lil_matrix((size, 1))
- for key, value in values.items():
- lil[key, 0] = value
- return lil
-
- def test_clustering(self):
- from pyspark.mllib.clustering import KMeans
- data = [
- self.scipy_matrix(3, {1: 1.0}),
- self.scipy_matrix(3, {1: 1.1}),
- self.scipy_matrix(3, {2: 1.0}),
- self.scipy_matrix(3, {2: 1.1})
- ]
- clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||")
- self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1]))
- self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3]))
-
- def test_classification(self):
- from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
- from pyspark.mllib.tree import DecisionTree
- data = [
- LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})),
- LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})),
- LabeledPoint(0.0, self.scipy_matrix(2, {0: 2.0})),
- LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0}))
- ]
- rdd = self.sc.parallelize(data)
- features = [p.features for p in data]
-
- lr_model = LogisticRegressionWithSGD.train(rdd)
- self.assertTrue(lr_model.predict(features[0]) <= 0)
- self.assertTrue(lr_model.predict(features[1]) > 0)
- self.assertTrue(lr_model.predict(features[2]) <= 0)
- self.assertTrue(lr_model.predict(features[3]) > 0)
-
- svm_model = SVMWithSGD.train(rdd)
- self.assertTrue(svm_model.predict(features[0]) <= 0)
- self.assertTrue(svm_model.predict(features[1]) > 0)
- self.assertTrue(svm_model.predict(features[2]) <= 0)
- self.assertTrue(svm_model.predict(features[3]) > 0)
-
- nb_model = NaiveBayes.train(rdd)
- self.assertTrue(nb_model.predict(features[0]) <= 0)
- self.assertTrue(nb_model.predict(features[1]) > 0)
- self.assertTrue(nb_model.predict(features[2]) <= 0)
- self.assertTrue(nb_model.predict(features[3]) > 0)
-
- categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
- dt_model = DecisionTree.trainClassifier(rdd, numClasses=2,
- categoricalFeaturesInfo=categoricalFeaturesInfo)
- self.assertTrue(dt_model.predict(features[0]) <= 0)
- self.assertTrue(dt_model.predict(features[1]) > 0)
- self.assertTrue(dt_model.predict(features[2]) <= 0)
- self.assertTrue(dt_model.predict(features[3]) > 0)
-
- def test_regression(self):
- from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
- RidgeRegressionWithSGD
- from pyspark.mllib.tree import DecisionTree
- data = [
- LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})),
- LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})),
- LabeledPoint(-1.0, self.scipy_matrix(2, {1: -2.0})),
- LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0}))
- ]
- rdd = self.sc.parallelize(data)
- features = [p.features for p in data]
-
- lr_model = LinearRegressionWithSGD.train(rdd)
- self.assertTrue(lr_model.predict(features[0]) <= 0)
- self.assertTrue(lr_model.predict(features[1]) > 0)
- self.assertTrue(lr_model.predict(features[2]) <= 0)
- self.assertTrue(lr_model.predict(features[3]) > 0)
-
- lasso_model = LassoWithSGD.train(rdd)
- self.assertTrue(lasso_model.predict(features[0]) <= 0)
- self.assertTrue(lasso_model.predict(features[1]) > 0)
- self.assertTrue(lasso_model.predict(features[2]) <= 0)
- self.assertTrue(lasso_model.predict(features[3]) > 0)
-
- rr_model = RidgeRegressionWithSGD.train(rdd)
- self.assertTrue(rr_model.predict(features[0]) <= 0)
- self.assertTrue(rr_model.predict(features[1]) > 0)
- self.assertTrue(rr_model.predict(features[2]) <= 0)
- self.assertTrue(rr_model.predict(features[3]) > 0)
-
- categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
- dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
- self.assertTrue(dt_model.predict(features[0]) <= 0)
- self.assertTrue(dt_model.predict(features[1]) > 0)
- self.assertTrue(dt_model.predict(features[2]) <= 0)
- self.assertTrue(dt_model.predict(features[3]) > 0)
-
-
-class ChiSqTestTests(MLlibTestCase):
- def test_goodness_of_fit(self):
- from numpy import inf
-
- observed = Vectors.dense([4, 6, 5])
- pearson = Statistics.chiSqTest(observed)
-
- # Validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))`
- self.assertEqual(pearson.statistic, 0.4)
- self.assertEqual(pearson.degreesOfFreedom, 2)
- self.assertAlmostEqual(pearson.pValue, 0.8187, 4)
-
- # Different expected and observed sum
- observed1 = Vectors.dense([21, 38, 43, 80])
- expected1 = Vectors.dense([3, 5, 7, 20])
- pearson1 = Statistics.chiSqTest(observed1, expected1)
-
- # Results validated against the R command
- # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))`
- self.assertAlmostEqual(pearson1.statistic, 14.1429, 4)
- self.assertEqual(pearson1.degreesOfFreedom, 3)
- self.assertAlmostEqual(pearson1.pValue, 0.002717, 4)
-
- # Vectors with different sizes
- observed3 = Vectors.dense([1.0, 2.0, 3.0])
- expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0])
- self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3)
-
- # Negative counts in observed
- neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0])
- self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1)
-
- # Count = 0.0 in expected but not observed
- zero_expected = Vectors.dense([1.0, 0.0, 3.0])
- pearson_inf = Statistics.chiSqTest(observed, zero_expected)
- self.assertEqual(pearson_inf.statistic, inf)
- self.assertEqual(pearson_inf.degreesOfFreedom, 2)
- self.assertEqual(pearson_inf.pValue, 0.0)
-
- # 0.0 in expected and observed simultaneously
- zero_observed = Vectors.dense([2.0, 0.0, 1.0])
- self.assertRaises(
- IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected)
-
- def test_matrix_independence(self):
- data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0]
- chi = Statistics.chiSqTest(Matrices.dense(3, 4, data))
-
- # Results validated against R command
- # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))`
- self.assertAlmostEqual(chi.statistic, 21.9958, 4)
- self.assertEqual(chi.degreesOfFreedom, 6)
- self.assertAlmostEqual(chi.pValue, 0.001213, 4)
-
- # Negative counts
- neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0])
- self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts)
-
- # Row sum = 0.0
- row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0])
- self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero)
-
- # Column sum = 0.0
- col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0])
- self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero)
-
- def test_chi_sq_pearson(self):
- data = [
- LabeledPoint(0.0, Vectors.dense([0.5, 10.0])),
- LabeledPoint(0.0, Vectors.dense([1.5, 20.0])),
- LabeledPoint(1.0, Vectors.dense([1.5, 30.0])),
- LabeledPoint(0.0, Vectors.dense([3.5, 30.0])),
- LabeledPoint(0.0, Vectors.dense([3.5, 40.0])),
- LabeledPoint(1.0, Vectors.dense([3.5, 40.0]))
- ]
-
- for numParts in [2, 4, 6, 8]:
- chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts))
- feature1 = chi[0]
- self.assertEqual(feature1.statistic, 0.75)
- self.assertEqual(feature1.degreesOfFreedom, 2)
- self.assertAlmostEqual(feature1.pValue, 0.6873, 4)
-
- feature2 = chi[1]
- self.assertEqual(feature2.statistic, 1.5)
- self.assertEqual(feature2.degreesOfFreedom, 3)
- self.assertAlmostEqual(feature2.pValue, 0.6823, 4)
-
- def test_right_number_of_results(self):
- num_cols = 1001
- sparse_data = [
- LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])),
- LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)]))
- ]
- chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data))
- self.assertEqual(len(chi), num_cols)
- self.assertIsNotNone(chi[1000])
-
-
-class KolmogorovSmirnovTest(MLlibTestCase):
-
- def test_R_implementation_equivalence(self):
- data = self.sc.parallelize([
- 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501,
- -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555,
- -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063,
- -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691,
- 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942
- ])
- model = Statistics.kolmogorovSmirnovTest(data, "norm")
- self.assertAlmostEqual(model.statistic, 0.189, 3)
- self.assertAlmostEqual(model.pValue, 0.422, 3)
-
- model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1)
- self.assertAlmostEqual(model.statistic, 0.189, 3)
- self.assertAlmostEqual(model.pValue, 0.422, 3)
-
-
-class SerDeTest(MLlibTestCase):
- def test_to_java_object_rdd(self): # SPARK-6660
- data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0)
- self.assertEqual(_to_java_object_rdd(data).count(), 10)
-
-
-class FeatureTest(MLlibTestCase):
- def test_idf_model(self):
- data = [
- Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]),
- Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]),
- Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]),
- Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9])
- ]
- model = IDF().fit(self.sc.parallelize(data, 2))
- idf = model.idf()
- self.assertEqual(len(idf), 11)
-
-
-class Word2VecTests(MLlibTestCase):
- def test_word2vec_setters(self):
- model = Word2Vec() \
- .setVectorSize(2) \
- .setLearningRate(0.01) \
- .setNumPartitions(2) \
- .setNumIterations(10) \
- .setSeed(1024) \
- .setMinCount(3) \
- .setWindowSize(6)
- self.assertEqual(model.vectorSize, 2)
- self.assertTrue(model.learningRate < 0.02)
- self.assertEqual(model.numPartitions, 2)
- self.assertEqual(model.numIterations, 10)
- self.assertEqual(model.seed, 1024)
- self.assertEqual(model.minCount, 3)
- self.assertEqual(model.windowSize, 6)
-
- def test_word2vec_get_vectors(self):
- data = [
- ["a", "b", "c", "d", "e", "f", "g"],
- ["a", "b", "c", "d", "e", "f"],
- ["a", "b", "c", "d", "e"],
- ["a", "b", "c", "d"],
- ["a", "b", "c"],
- ["a", "b"],
- ["a"]
- ]
- model = Word2Vec().fit(self.sc.parallelize(data))
- self.assertEqual(len(model.getVectors()), 3)
-
-
-class StandardScalerTests(MLlibTestCase):
- def test_model_setters(self):
- data = [
- [1.0, 2.0, 3.0],
- [2.0, 3.0, 4.0],
- [3.0, 4.0, 5.0]
- ]
- model = StandardScaler().fit(self.sc.parallelize(data))
- self.assertIsNotNone(model.setWithMean(True))
- self.assertIsNotNone(model.setWithStd(True))
- self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0]))
-
- def test_model_transform(self):
- data = [
- [1.0, 2.0, 3.0],
- [2.0, 3.0, 4.0],
- [3.0, 4.0, 5.0]
- ]
- model = StandardScaler().fit(self.sc.parallelize(data))
- self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0]))
-
-
-class ElementwiseProductTests(MLlibTestCase):
- def test_model_transform(self):
- weight = Vectors.dense([3, 2, 1])
-
- densevec = Vectors.dense([4, 5, 6])
- sparsevec = Vectors.sparse(3, [0], [1])
- eprod = ElementwiseProduct(weight)
- self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6]))
- self.assertEqual(
- eprod.transform(sparsevec), SparseVector(3, [0], [3]))
-
-
-class StreamingKMeansTest(MLLibStreamingTestCase):
- def test_model_params(self):
- """Test that the model params are set correctly"""
- stkm = StreamingKMeans()
- stkm.setK(5).setDecayFactor(0.0)
- self.assertEqual(stkm._k, 5)
- self.assertEqual(stkm._decayFactor, 0.0)
-
- # Model not set yet.
- self.assertIsNone(stkm.latestModel())
- self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0])
-
- stkm.setInitialCenters(
- centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0])
- self.assertEqual(
- stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]])
- self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0])
-
- def test_accuracy_for_single_center(self):
- """Test that parameters obtained are correct for a single center."""
- centers, batches = self.streamingKMeansDataGenerator(
- batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0)
- stkm = StreamingKMeans(1)
- stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.])
- input_stream = self.ssc.queueStream(
- [self.sc.parallelize(batch, 1) for batch in batches])
- stkm.trainOn(input_stream)
-
- self.ssc.start()
-
- def condition():
- self.assertEqual(stkm.latestModel().clusterWeights, [25.0])
- return True
- self._eventually(condition, catch_assertions=True)
-
- realCenters = array_sum(array(centers), axis=0)
- for i in range(5):
- modelCenters = stkm.latestModel().centers[0][i]
- self.assertAlmostEqual(centers[0][i], modelCenters, 1)
- self.assertAlmostEqual(realCenters[i], modelCenters, 1)
-
- def streamingKMeansDataGenerator(self, batches, numPoints,
- k, d, r, seed, centers=None):
- rng = random.RandomState(seed)
-
- # Generate centers.
- centers = [rng.randn(d) for i in range(k)]
-
- return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d))
- for j in range(numPoints)]
- for i in range(batches)]
-
- def test_trainOn_model(self):
- """Test the model on toy data with four clusters."""
- stkm = StreamingKMeans()
- initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]]
- stkm.setInitialCenters(
- centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0])
-
- # Create a toy dataset by setting a tiny offset for each point.
- offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]]
- batches = []
- for offset in offsets:
- batches.append([[offset[0] + center[0], offset[1] + center[1]]
- for center in initCenters])
-
- batches = [self.sc.parallelize(batch, 1) for batch in batches]
- input_stream = self.ssc.queueStream(batches)
- stkm.trainOn(input_stream)
- self.ssc.start()
-
- # Give enough time to train the model.
- def condition():
- finalModel = stkm.latestModel()
- self.assertTrue(all(finalModel.centers == array(initCenters)))
- self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0])
- return True
- self._eventually(condition, catch_assertions=True)
-
- def test_predictOn_model(self):
- """Test that the model predicts correctly on toy data."""
- stkm = StreamingKMeans()
- stkm._model = StreamingKMeansModel(
- clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]],
- clusterWeights=[1.0, 1.0, 1.0, 1.0])
-
- predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]]
- predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data]
- predict_stream = self.ssc.queueStream(predict_data)
- predict_val = stkm.predictOn(predict_stream)
-
- result = []
-
- def update(rdd):
- rdd_collect = rdd.collect()
- if rdd_collect:
- result.append(rdd_collect)
-
- predict_val.foreachRDD(update)
- self.ssc.start()
-
- def condition():
- self.assertEqual(result, [[0], [1], [2], [3]])
- return True
-
- self._eventually(condition, catch_assertions=True)
-
- @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark")
- def test_trainOn_predictOn(self):
- """Test that prediction happens on the updated model."""
- stkm = StreamingKMeans(decayFactor=0.0, k=2)
- stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0])
-
-        # Since the decay factor is set to zero, the cluster centers are
-        # updated to [-0.5, 0.7] once the first batch is passed. This causes
-        # 0.2 and 0.3 to be classified as 1, even though classification based
-        # on the initial model would have been 0, proving that the model is
-        # updated.
- batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]]
- batches = [self.sc.parallelize(batch) for batch in batches]
- input_stream = self.ssc.queueStream(batches)
- predict_results = []
-
- def collect(rdd):
- rdd_collect = rdd.collect()
- if rdd_collect:
- predict_results.append(rdd_collect)
-
- stkm.trainOn(input_stream)
- predict_stream = stkm.predictOn(input_stream)
- predict_stream.foreachRDD(collect)
-
- self.ssc.start()
-
- def condition():
- self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
- return True
-
- self._eventually(condition, catch_assertions=True)
-
-
-class LinearDataGeneratorTests(MLlibTestCase):
- def test_dim(self):
- linear_data = LinearDataGenerator.generateLinearInput(
- intercept=0.0, weights=[0.0, 0.0, 0.0],
- xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33],
- nPoints=4, seed=0, eps=0.1)
- self.assertEqual(len(linear_data), 4)
- for point in linear_data:
- self.assertEqual(len(point.features), 3)
-
- linear_data = LinearDataGenerator.generateLinearRDD(
- sc=self.sc, nexamples=6, nfeatures=2, eps=0.1,
- nParts=2, intercept=0.0).collect()
- self.assertEqual(len(linear_data), 6)
- for point in linear_data:
- self.assertEqual(len(point.features), 2)
-
-
-class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
-
- @staticmethod
- def generateLogisticInput(offset, scale, nPoints, seed):
- """
-        Generate labels from 1 / (1 + exp(-(x * scale + offset)))
-
-        where x is normally distributed and the label for each sample
-        is obtained by comparing a uniform random draw against that
-        sigmoid value.
- """
- rng = random.RandomState(seed)
- x = rng.randn(nPoints)
- sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset)))
- y_p = rng.rand(nPoints)
- cut_off = y_p <= sigmoid
- y_p[cut_off] = 1.0
- y_p[~cut_off] = 0.0
- return [
- LabeledPoint(y_p[i], Vectors.dense([x[i]]))
- for i in range(nPoints)]
-
- @unittest.skip("Super flaky test")
- def test_parameter_accuracy(self):
- """
- Test that the final value of weights is close to the desired value.
- """
- input_batches = [
- self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
- for i in range(20)]
- input_stream = self.ssc.queueStream(input_batches)
-
- slr = StreamingLogisticRegressionWithSGD(
- stepSize=0.2, numIterations=25)
- slr.setInitialWeights([0.0])
- slr.trainOn(input_stream)
-
- self.ssc.start()
-
- def condition():
- rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5
- self.assertAlmostEqual(rel, 0.1, 1)
- return True
-
- self._eventually(condition, catch_assertions=True)
-
- def test_convergence(self):
- """
- Test that weights converge to the required value on toy data.
- """
- input_batches = [
- self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
- for i in range(20)]
- input_stream = self.ssc.queueStream(input_batches)
- models = []
-
- slr = StreamingLogisticRegressionWithSGD(
- stepSize=0.2, numIterations=25)
- slr.setInitialWeights([0.0])
- slr.trainOn(input_stream)
- input_stream.foreachRDD(
- lambda x: models.append(slr.latestModel().weights[0]))
-
- self.ssc.start()
-
- def condition():
- self.assertEqual(len(models), len(input_batches))
- return True
-
- # We want all batches to finish for this test.
- self._eventually(condition, 60.0, catch_assertions=True)
-
- t_models = array(models)
- diff = t_models[1:] - t_models[:-1]
- # Test that weights improve with a small tolerance
- self.assertTrue(all(diff >= -0.1))
- self.assertTrue(array_sum(diff > 0) > 1)
-
- @staticmethod
- def calculate_accuracy_error(true, predicted):
- return sum(abs(array(true) - array(predicted))) / len(true)
-
- def test_predictions(self):
- """Test predicted values on a toy model."""
- input_batches = []
- for i in range(20):
- batch = self.sc.parallelize(
- self.generateLogisticInput(0, 1.5, 100, 42 + i))
- input_batches.append(batch.map(lambda x: (x.label, x.features)))
- input_stream = self.ssc.queueStream(input_batches)
-
- slr = StreamingLogisticRegressionWithSGD(
- stepSize=0.2, numIterations=25)
- slr.setInitialWeights([1.5])
- predict_stream = slr.predictOnValues(input_stream)
- true_predicted = []
- predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect()))
- self.ssc.start()
-
- def condition():
- self.assertEqual(len(true_predicted), len(input_batches))
- return True
-
- self._eventually(condition, catch_assertions=True)
-
- # Test that the accuracy error is no more than 0.4 on each batch.
- for batch in true_predicted:
- true, predicted = zip(*batch)
- self.assertTrue(
- self.calculate_accuracy_error(true, predicted) < 0.4)
-
- @unittest.skip("Super flaky test")
- def test_training_and_prediction(self):
- """Test that the model improves on toy data with no. of batches"""
- input_batches = [
- self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
- for i in range(20)]
- predict_batches = [
- b.map(lambda lp: (lp.label, lp.features)) for b in input_batches]
-
- slr = StreamingLogisticRegressionWithSGD(
- stepSize=0.01, numIterations=25)
- slr.setInitialWeights([-0.1])
- errors = []
-
- def collect_errors(rdd):
- true, predicted = zip(*rdd.collect())
- errors.append(self.calculate_accuracy_error(true, predicted))
-
- true_predicted = []
- input_stream = self.ssc.queueStream(input_batches)
- predict_stream = self.ssc.queueStream(predict_batches)
- slr.trainOn(input_stream)
- ps = slr.predictOnValues(predict_stream)
- ps.foreachRDD(lambda x: collect_errors(x))
-
- self.ssc.start()
-
- def condition():
- # Test that the improvement in error is > 0.3
- if len(errors) == len(predict_batches):
- self.assertGreater(errors[1] - errors[-1], 0.3)
- if len(errors) >= 3 and errors[1] - errors[-1] > 0.3:
- return True
- return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
-
- self._eventually(condition)
-
-
-class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
-
- def assertArrayAlmostEqual(self, array1, array2, dec):
-        for i, j in zip(array1, array2):
- self.assertAlmostEqual(i, j, dec)
-
- @unittest.skip("Super flaky test")
- def test_parameter_accuracy(self):
- """Test that coefs are predicted accurately by fitting on toy data."""
-
- # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients
- # (10, 10)
- slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
- slr.setInitialWeights([0.0, 0.0])
- xMean = [0.0, 0.0]
- xVariance = [1.0 / 3.0, 1.0 / 3.0]
-
- # Create ten batches with 100 sample points in each.
- batches = []
- for i in range(10):
- batch = LinearDataGenerator.generateLinearInput(
- 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1)
- batches.append(self.sc.parallelize(batch))
-
- input_stream = self.ssc.queueStream(batches)
- slr.trainOn(input_stream)
- self.ssc.start()
-
- def condition():
- self.assertArrayAlmostEqual(
- slr.latestModel().weights.array, [10., 10.], 1)
- self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1)
- return True
-
- self._eventually(condition, catch_assertions=True)
-
- def test_parameter_convergence(self):
- """Test that the model parameters improve with streaming data."""
- slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
- slr.setInitialWeights([0.0])
-
- # Create ten batches with 100 sample points in each.
- batches = []
- for i in range(10):
- batch = LinearDataGenerator.generateLinearInput(
- 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
- batches.append(self.sc.parallelize(batch))
-
- model_weights = []
- input_stream = self.ssc.queueStream(batches)
- input_stream.foreachRDD(
- lambda x: model_weights.append(slr.latestModel().weights[0]))
- slr.trainOn(input_stream)
- self.ssc.start()
-
- def condition():
- self.assertEqual(len(model_weights), len(batches))
- return True
-
- # We want all batches to finish for this test.
- self._eventually(condition, catch_assertions=True)
-
- w = array(model_weights)
- diff = w[1:] - w[:-1]
- self.assertTrue(all(diff >= -0.1))
-
- def test_prediction(self):
- """Test prediction on a model with weights already set."""
-        # Create a model with initial weights equal to the true coefficients.
- slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
- slr.setInitialWeights([10.0, 10.0])
-
- # Create ten batches with 100 sample points in each.
- batches = []
- for i in range(10):
- batch = LinearDataGenerator.generateLinearInput(
- 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0],
- 100, 42 + i, 0.1)
- batches.append(
- self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features)))
-
- input_stream = self.ssc.queueStream(batches)
- output_stream = slr.predictOnValues(input_stream)
- samples = []
- output_stream.foreachRDD(lambda x: samples.append(x.collect()))
-
- self.ssc.start()
-
- def condition():
- self.assertEqual(len(samples), len(batches))
- return True
-
- # We want all batches to finish for this test.
- self._eventually(condition, catch_assertions=True)
-
- # Test that mean absolute error on each batch is less than 0.1
- for batch in samples:
- true, predicted = zip(*batch)
- self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1)
-
- @unittest.skip("Super flaky test")
- def test_train_prediction(self):
- """Test that error on test data improves as model is trained."""
- slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
- slr.setInitialWeights([0.0])
-
- # Create ten batches with 100 sample points in each.
- batches = []
- for i in range(10):
- batch = LinearDataGenerator.generateLinearInput(
- 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
- batches.append(self.sc.parallelize(batch))
-
- predict_batches = [
- b.map(lambda lp: (lp.label, lp.features)) for b in batches]
- errors = []
-
- def func(rdd):
- true, predicted = zip(*rdd.collect())
-            errors.append(mean(abs(array(true) - array(predicted))))
-
- input_stream = self.ssc.queueStream(batches)
- output_stream = self.ssc.queueStream(predict_batches)
- slr.trainOn(input_stream)
- output_stream = slr.predictOnValues(output_stream)
- output_stream.foreachRDD(func)
- self.ssc.start()
-
- def condition():
- if len(errors) == len(predict_batches):
- self.assertGreater(errors[1] - errors[-1], 2)
- if len(errors) >= 3 and errors[1] - errors[-1] > 2:
- return True
- return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
-
- self._eventually(condition)
-
-
-class MLUtilsTests(MLlibTestCase):
- def test_append_bias(self):
- data = [2.0, 2.0, 2.0]
- ret = MLUtils.appendBias(data)
- self.assertEqual(ret[3], 1.0)
- self.assertEqual(type(ret), DenseVector)
-
- def test_append_bias_with_vector(self):
- data = Vectors.dense([2.0, 2.0, 2.0])
- ret = MLUtils.appendBias(data)
- self.assertEqual(ret[3], 1.0)
- self.assertEqual(type(ret), DenseVector)
-
- def test_append_bias_with_sp_vector(self):
- data = Vectors.sparse(3, {0: 2.0, 2: 2.0})
- expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0})
- # Returned value must be SparseVector
- ret = MLUtils.appendBias(data)
- self.assertEqual(ret, expected)
- self.assertEqual(type(ret), SparseVector)
-
- def test_load_vectors(self):
- import shutil
- data = [
- [1.0, 2.0, 3.0],
- [1.0, 2.0, 3.0]
- ]
- temp_dir = tempfile.mkdtemp()
- load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
- try:
- self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
- ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
- ret = ret_rdd.collect()
- self.assertEqual(len(ret), 2)
- self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
- self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
- except:
- self.fail()
- finally:
- shutil.rmtree(load_vectors_path)
-
-
-class ALSTests(MLlibTestCase):
-
- def test_als_ratings_serialize(self):
- r = Rating(7, 1123, 3.14)
- jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
- nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr)))
- self.assertEqual(r.user, nr.user)
- self.assertEqual(r.product, nr.product)
- self.assertAlmostEqual(r.rating, nr.rating, 2)
-
- def test_als_ratings_id_long_error(self):
- r = Rating(1205640308657491975, 50233468418, 1.0)
- # rating user id exceeds max int value, should fail when pickled
- self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads,
- bytearray(ser.dumps(r)))
-
-
-class HashingTFTest(MLlibTestCase):
-
- def test_binary_term_freqs(self):
- hashingTF = HashingTF(100).setBinary(True)
- doc = "a a b c c c".split(" ")
- n = hashingTF.numFeatures
- output = hashingTF.transform(doc).toArray()
- expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0,
- hashingTF.indexOf("b"): 1.0,
- hashingTF.indexOf("c"): 1.0}).toArray()
- for i in range(0, n):
- self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) +
- ": expected " + str(expected[i]) + ", got " + str(output[i]))
-
-
-class DimensionalityReductionTests(MLlibTestCase):
-
- denseData = [
- Vectors.dense([0.0, 1.0, 2.0]),
- Vectors.dense([3.0, 4.0, 5.0]),
- Vectors.dense([6.0, 7.0, 8.0]),
- Vectors.dense([9.0, 0.0, 1.0])
- ]
- sparseData = [
- Vectors.sparse(3, [(1, 1.0), (2, 2.0)]),
- Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]),
- Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]),
- Vectors.sparse(3, [(0, 9.0), (2, 1.0)])
- ]
-
- def assertEqualUpToSign(self, vecA, vecB):
- eq1 = vecA - vecB
- eq2 = vecA + vecB
- self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6)
-
- def test_svd(self):
- denseMat = RowMatrix(self.sc.parallelize(self.denseData))
- sparseMat = RowMatrix(self.sc.parallelize(self.sparseData))
- m = 4
- n = 3
- for mat in [denseMat, sparseMat]:
- for k in range(1, 4):
- rm = mat.computeSVD(k, computeU=True)
- self.assertEqual(rm.s.size, k)
- self.assertEqual(rm.U.numRows(), m)
- self.assertEqual(rm.U.numCols(), k)
- self.assertEqual(rm.V.numRows, n)
- self.assertEqual(rm.V.numCols, k)
-
- # Test that U returned is None if computeU is set to False.
- self.assertEqual(mat.computeSVD(1).U, None)
-
- # Test that low rank matrices cannot have number of singular values
- # greater than a limit.
- rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1))))
- self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1)
-
- def test_pca(self):
- expected_pcs = array([
- [0.0, 1.0, 0.0],
- [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0],
- [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0]
- ])
- n = 3
- denseMat = RowMatrix(self.sc.parallelize(self.denseData))
- sparseMat = RowMatrix(self.sc.parallelize(self.sparseData))
- for mat in [denseMat, sparseMat]:
- for k in range(1, 4):
- pcs = mat.computePrincipalComponents(k)
- self.assertEqual(pcs.numRows, n)
- self.assertEqual(pcs.numCols, k)
-
- # We can just test the updated principal component for equality.
- self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1])
-
-
-class FPGrowthTest(MLlibTestCase):
-
- def test_fpgrowth(self):
- data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
- rdd = self.sc.parallelize(data, 2)
- model1 = FPGrowth.train(rdd, 0.6, 2)
- # use default data partition number when numPartitions is not specified
- model2 = FPGrowth.train(rdd, 0.6)
- self.assertEqual(sorted(model1.freqItemsets().collect()),
- sorted(model2.freqItemsets().collect()))
-
-if __name__ == "__main__":
- from pyspark.mllib.tests import *
- if not _have_scipy:
- print("NOTE: Skipping SciPy tests as it does not seem to be installed")
- runner = unishark.BufferedTestRunner(
- reporters=[unishark.XUnitReporter('target/test-reports/pyspark.mllib_{}'.format(
- os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))])
- unittest.main(testRunner=runner, verbosity=2)
- if not _have_scipy:
- print("NOTE: SciPy tests were skipped as it does not seem to be installed")
- sc.stop()
diff --git a/python/pyspark/mllib/tests/__init__.py b/python/pyspark/mllib/tests/__init__.py
new file mode 100644
index 0000000000000..cce3acad34a49
--- /dev/null
+++ b/python/pyspark/mllib/tests/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py
new file mode 100644
index 0000000000000..cc3b64b1cb284
--- /dev/null
+++ b/python/pyspark/mllib/tests/test_algorithms.py
@@ -0,0 +1,302 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+from shutil import rmtree
+import unittest
+
+from numpy import array, array_equal
+from py4j.protocol import Py4JJavaError
+
+from pyspark.mllib.fpm import FPGrowth
+from pyspark.mllib.recommendation import Rating
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
+
+
+ser = make_serializer()
+
+
+class ListTests(MLlibTestCase):
+
+ """
+ Test MLlib algorithms on plain lists, to make sure they're passed through
+ as NumPy arrays.
+ """
+
+ def test_bisecting_kmeans(self):
+ from pyspark.mllib.clustering import BisectingKMeans
+ data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2)
+ bskm = BisectingKMeans()
+ model = bskm.train(self.sc.parallelize(data, 2), k=4)
+ p = array([0.0, 0.0])
+ rdd_p = self.sc.parallelize([p])
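+        # Predicting a single point directly or via a one-element RDD should agree.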
+ self.assertEqual(model.predict(p), model.predict(rdd_p).first())
+ self.assertEqual(model.computeCost(p), model.computeCost(rdd_p))
+ self.assertEqual(model.k, len(model.clusterCenters))
+
+ def test_kmeans(self):
+ from pyspark.mllib.clustering import KMeans
+ data = [
+ [0, 1.1],
+ [0, 1.2],
+ [1.1, 0],
+ [1.2, 0],
+ ]
+ clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||",
+ initializationSteps=7, epsilon=1e-4)
+ self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1]))
+ self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3]))
+
+ def test_kmeans_deterministic(self):
+ from pyspark.mllib.clustering import KMeans
+ X = range(0, 100, 10)
+ Y = range(0, 100, 10)
+ data = [[x, y] for x, y in zip(X, Y)]
+ clusters1 = KMeans.train(self.sc.parallelize(data),
+ 3, initializationMode="k-means||",
+ seed=42, initializationSteps=7, epsilon=1e-4)
+ clusters2 = KMeans.train(self.sc.parallelize(data),
+ 3, initializationMode="k-means||",
+ seed=42, initializationSteps=7, epsilon=1e-4)
+ centers1 = clusters1.centers
+ centers2 = clusters2.centers
+ for c1, c2 in zip(centers1, centers2):
+ # TODO: Allow small numeric difference.
+ self.assertTrue(array_equal(c1, c2))
+
+ def test_gmm(self):
+ from pyspark.mllib.clustering import GaussianMixture
+ data = self.sc.parallelize([
+ [1, 2],
+ [8, 9],
+ [-4, -3],
+ [-6, -7],
+ ])
+ clusters = GaussianMixture.train(data, 2, convergenceTol=0.001,
+ maxIterations=10, seed=1)
+ labels = clusters.predict(data).collect()
+ self.assertEqual(labels[0], labels[1])
+ self.assertEqual(labels[2], labels[3])
+
+ def test_gmm_deterministic(self):
+ from pyspark.mllib.clustering import GaussianMixture
+ x = range(0, 100, 10)
+ y = range(0, 100, 10)
+ data = self.sc.parallelize([[a, b] for a, b in zip(x, y)])
+ clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001,
+ maxIterations=10, seed=63)
+ clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001,
+ maxIterations=10, seed=63)
+ for c1, c2 in zip(clusters1.weights, clusters2.weights):
+ self.assertEqual(round(c1, 7), round(c2, 7))
+
+ def test_gmm_with_initial_model(self):
+ from pyspark.mllib.clustering import GaussianMixture
+ data = self.sc.parallelize([
+ (-10, -5), (-9, -4), (10, 5), (9, 4)
+ ])
+
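+        # Train one model, then use it as the initial model for a second run and compare weights.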
+ gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
+ maxIterations=10, seed=63)
+ gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
+ maxIterations=10, seed=63, initialModel=gmm1)
+ self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0)
+
+ def test_classification(self):
+ from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
+ from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest, \
+ RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel
+ data = [
+ LabeledPoint(0.0, [1, 0, 0]),
+ LabeledPoint(1.0, [0, 1, 1]),
+ LabeledPoint(0.0, [2, 0, 0]),
+ LabeledPoint(1.0, [0, 2, 1])
+ ]
+ rdd = self.sc.parallelize(data)
+ features = [p.features.tolist() for p in data]
+
+ temp_dir = tempfile.mkdtemp()
+
+ lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10)
+ self.assertTrue(lr_model.predict(features[0]) <= 0)
+ self.assertTrue(lr_model.predict(features[1]) > 0)
+ self.assertTrue(lr_model.predict(features[2]) <= 0)
+ self.assertTrue(lr_model.predict(features[3]) > 0)
+
+ svm_model = SVMWithSGD.train(rdd, iterations=10)
+ self.assertTrue(svm_model.predict(features[0]) <= 0)
+ self.assertTrue(svm_model.predict(features[1]) > 0)
+ self.assertTrue(svm_model.predict(features[2]) <= 0)
+ self.assertTrue(svm_model.predict(features[3]) > 0)
+
+ nb_model = NaiveBayes.train(rdd)
+ self.assertTrue(nb_model.predict(features[0]) <= 0)
+ self.assertTrue(nb_model.predict(features[1]) > 0)
+ self.assertTrue(nb_model.predict(features[2]) <= 0)
+ self.assertTrue(nb_model.predict(features[3]) > 0)
+
+ categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
+ dt_model = DecisionTree.trainClassifier(
+ rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4)
+ self.assertTrue(dt_model.predict(features[0]) <= 0)
+ self.assertTrue(dt_model.predict(features[1]) > 0)
+ self.assertTrue(dt_model.predict(features[2]) <= 0)
+ self.assertTrue(dt_model.predict(features[3]) > 0)
+
+ dt_model_dir = os.path.join(temp_dir, "dt")
+ dt_model.save(self.sc, dt_model_dir)
+ same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir)
+ self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString())
+
+ rf_model = RandomForest.trainClassifier(
+ rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10,
+ maxBins=4, seed=1)
+ self.assertTrue(rf_model.predict(features[0]) <= 0)
+ self.assertTrue(rf_model.predict(features[1]) > 0)
+ self.assertTrue(rf_model.predict(features[2]) <= 0)
+ self.assertTrue(rf_model.predict(features[3]) > 0)
+
+ rf_model_dir = os.path.join(temp_dir, "rf")
+ rf_model.save(self.sc, rf_model_dir)
+ same_rf_model = RandomForestModel.load(self.sc, rf_model_dir)
+ self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString())
+
+ gbt_model = GradientBoostedTrees.trainClassifier(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4)
+ self.assertTrue(gbt_model.predict(features[0]) <= 0)
+ self.assertTrue(gbt_model.predict(features[1]) > 0)
+ self.assertTrue(gbt_model.predict(features[2]) <= 0)
+ self.assertTrue(gbt_model.predict(features[3]) > 0)
+
+ gbt_model_dir = os.path.join(temp_dir, "gbt")
+ gbt_model.save(self.sc, gbt_model_dir)
+ same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir)
+ self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString())
+
+ try:
+ rmtree(temp_dir)
+ except OSError:
+ pass
+
+ def test_regression(self):
+ from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
+ RidgeRegressionWithSGD
+ from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
+ data = [
+ LabeledPoint(-1.0, [0, -1]),
+ LabeledPoint(1.0, [0, 1]),
+ LabeledPoint(-1.0, [0, -2]),
+ LabeledPoint(1.0, [0, 2])
+ ]
+ rdd = self.sc.parallelize(data)
+ features = [p.features.tolist() for p in data]
+
+ lr_model = LinearRegressionWithSGD.train(rdd, iterations=10)
+ self.assertTrue(lr_model.predict(features[0]) <= 0)
+ self.assertTrue(lr_model.predict(features[1]) > 0)
+ self.assertTrue(lr_model.predict(features[2]) <= 0)
+ self.assertTrue(lr_model.predict(features[3]) > 0)
+
+ lasso_model = LassoWithSGD.train(rdd, iterations=10)
+ self.assertTrue(lasso_model.predict(features[0]) <= 0)
+ self.assertTrue(lasso_model.predict(features[1]) > 0)
+ self.assertTrue(lasso_model.predict(features[2]) <= 0)
+ self.assertTrue(lasso_model.predict(features[3]) > 0)
+
+ rr_model = RidgeRegressionWithSGD.train(rdd, iterations=10)
+ self.assertTrue(rr_model.predict(features[0]) <= 0)
+ self.assertTrue(rr_model.predict(features[1]) > 0)
+ self.assertTrue(rr_model.predict(features[2]) <= 0)
+ self.assertTrue(rr_model.predict(features[3]) > 0)
+
+ categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
+ dt_model = DecisionTree.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4)
+ self.assertTrue(dt_model.predict(features[0]) <= 0)
+ self.assertTrue(dt_model.predict(features[1]) > 0)
+ self.assertTrue(dt_model.predict(features[2]) <= 0)
+ self.assertTrue(dt_model.predict(features[3]) > 0)
+
+ rf_model = RandomForest.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1)
+ self.assertTrue(rf_model.predict(features[0]) <= 0)
+ self.assertTrue(rf_model.predict(features[1]) > 0)
+ self.assertTrue(rf_model.predict(features[2]) <= 0)
+ self.assertTrue(rf_model.predict(features[3]) > 0)
+
+ gbt_model = GradientBoostedTrees.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4)
+ self.assertTrue(gbt_model.predict(features[0]) <= 0)
+ self.assertTrue(gbt_model.predict(features[1]) > 0)
+ self.assertTrue(gbt_model.predict(features[2]) <= 0)
+ self.assertTrue(gbt_model.predict(features[3]) > 0)
+
+ try:
+ LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10)
+ LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10)
+ RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10)
+ except ValueError:
+ self.fail()
+
+ # Verify that maxBins is being passed through
+ GradientBoostedTrees.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=32)
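+        # Assumption: maxBins=1 is too small to cover the 2 categories of feature 0,
+        # so training is expected to raise an exception.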
+ with self.assertRaises(Exception) as cm:
+ GradientBoostedTrees.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4, maxBins=1)
+
+
+class ALSTests(MLlibTestCase):
+
+ def test_als_ratings_serialize(self):
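+        # Round-trip a Rating through Python pickling and the JVM SerDe and
+        # check that user, product and rating survive.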
+ r = Rating(7, 1123, 3.14)
+ jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
+ nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr)))
+ self.assertEqual(r.user, nr.user)
+ self.assertEqual(r.product, nr.product)
+ self.assertAlmostEqual(r.rating, nr.rating, 2)
+
+ def test_als_ratings_id_long_error(self):
+ r = Rating(1205640308657491975, 50233468418, 1.0)
+        # The user id exceeds the JVM's max int value, so deserializing it on the JVM side should fail.
+ self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads,
+ bytearray(ser.dumps(r)))
+
+
+class FPGrowthTest(MLlibTestCase):
+
+ def test_fpgrowth(self):
+ data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]]
+ rdd = self.sc.parallelize(data, 2)
+ model1 = FPGrowth.train(rdd, 0.6, 2)
+        # Use the default number of data partitions when numPartitions is not specified.
+ model2 = FPGrowth.train(rdd, 0.6)
+ self.assertEqual(sorted(model1.freqItemsets().collect()),
+ sorted(model2.freqItemsets().collect()))
+
+
+if __name__ == "__main__":
+ from pyspark.mllib.tests.test_algorithms import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/mllib/tests/test_feature.py b/python/pyspark/mllib/tests/test_feature.py
new file mode 100644
index 0000000000000..3da841c408558
--- /dev/null
+++ b/python/pyspark/mllib/tests/test_feature.py
@@ -0,0 +1,192 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from math import sqrt
+import unittest
+
+from numpy import array, random, exp, abs, tile
+
+from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, Vectors
+from pyspark.mllib.linalg.distributed import RowMatrix
+from pyspark.mllib.feature import HashingTF, IDF, StandardScaler, ElementwiseProduct, Word2Vec
+from pyspark.testing.mllibutils import MLlibTestCase
+
+
+class FeatureTest(MLlibTestCase):
+ def test_idf_model(self):
+ data = [
+ Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]),
+ Vectors.dense([1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1]),
+ Vectors.dense([1, 4, 1, 0, 0, 4, 9, 0, 1, 2, 0]),
+ Vectors.dense([2, 1, 0, 3, 0, 0, 5, 0, 2, 3, 9])
+ ]
+ model = IDF().fit(self.sc.parallelize(data, 2))
+ idf = model.idf()
+ self.assertEqual(len(idf), 11)
+
+
+class Word2VecTests(MLlibTestCase):
+ def test_word2vec_setters(self):
+ model = Word2Vec() \
+ .setVectorSize(2) \
+ .setLearningRate(0.01) \
+ .setNumPartitions(2) \
+ .setNumIterations(10) \
+ .setSeed(1024) \
+ .setMinCount(3) \
+ .setWindowSize(6)
+ self.assertEqual(model.vectorSize, 2)
+ self.assertTrue(model.learningRate < 0.02)
+ self.assertEqual(model.numPartitions, 2)
+ self.assertEqual(model.numIterations, 10)
+ self.assertEqual(model.seed, 1024)
+ self.assertEqual(model.minCount, 3)
+ self.assertEqual(model.windowSize, 6)
+
+ def test_word2vec_get_vectors(self):
+ data = [
+ ["a", "b", "c", "d", "e", "f", "g"],
+ ["a", "b", "c", "d", "e", "f"],
+ ["a", "b", "c", "d", "e"],
+ ["a", "b", "c", "d"],
+ ["a", "b", "c"],
+ ["a", "b"],
+ ["a"]
+ ]
+ model = Word2Vec().fit(self.sc.parallelize(data))
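+        # Only "a", "b" and "c" occur at least 5 times (the default minCount),
+        # so the learned vocabulary should contain exactly 3 words.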
+ self.assertEqual(len(model.getVectors()), 3)
+
+
+class StandardScalerTests(MLlibTestCase):
+ def test_model_setters(self):
+ data = [
+ [1.0, 2.0, 3.0],
+ [2.0, 3.0, 4.0],
+ [3.0, 4.0, 5.0]
+ ]
+ model = StandardScaler().fit(self.sc.parallelize(data))
+ self.assertIsNotNone(model.setWithMean(True))
+ self.assertIsNotNone(model.setWithStd(True))
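+        # Each column has sample std 1.0, so after centering and scaling,
+        # [1, 2, 3] (one unit below each column mean) maps to [-1, -1, -1].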
+ self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([-1.0, -1.0, -1.0]))
+
+ def test_model_transform(self):
+ data = [
+ [1.0, 2.0, 3.0],
+ [2.0, 3.0, 4.0],
+ [3.0, 4.0, 5.0]
+ ]
+ model = StandardScaler().fit(self.sc.parallelize(data))
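+        # With the defaults (withMean=False, withStd=True) and unit column std,
+        # the transform should leave the vector unchanged.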
+ self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0]))
+
+
+class ElementwiseProductTests(MLlibTestCase):
+ def test_model_transform(self):
+ weight = Vectors.dense([3, 2, 1])
+
+ densevec = Vectors.dense([4, 5, 6])
+ sparsevec = Vectors.sparse(3, [0], [1])
+ eprod = ElementwiseProduct(weight)
+ self.assertEqual(eprod.transform(densevec), DenseVector([12, 10, 6]))
+ self.assertEqual(
+ eprod.transform(sparsevec), SparseVector(3, [0], [3]))
+
+
+class HashingTFTest(MLlibTestCase):
+
+ def test_binary_term_freqs(self):
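+        # With setBinary(True), every present term gets frequency 1.0, even though
+        # "a" and "c" appear more than once in the document.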
+ hashingTF = HashingTF(100).setBinary(True)
+ doc = "a a b c c c".split(" ")
+ n = hashingTF.numFeatures
+ output = hashingTF.transform(doc).toArray()
+ expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0,
+ hashingTF.indexOf("b"): 1.0,
+ hashingTF.indexOf("c"): 1.0}).toArray()
+ for i in range(0, n):
+ self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) +
+ ": expected " + str(expected[i]) + ", got " + str(output[i]))
+
+
+class DimensionalityReductionTests(MLlibTestCase):
+
+ denseData = [
+ Vectors.dense([0.0, 1.0, 2.0]),
+ Vectors.dense([3.0, 4.0, 5.0]),
+ Vectors.dense([6.0, 7.0, 8.0]),
+ Vectors.dense([9.0, 0.0, 1.0])
+ ]
+ sparseData = [
+ Vectors.sparse(3, [(1, 1.0), (2, 2.0)]),
+ Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]),
+ Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]),
+ Vectors.sparse(3, [(0, 9.0), (2, 1.0)])
+ ]
+
+ def assertEqualUpToSign(self, vecA, vecB):
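+        # Singular vectors and principal components are only determined up to sign,
+        # so accept either vecB or -vecB.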
+ eq1 = vecA - vecB
+ eq2 = vecA + vecB
+ self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6)
+
+ def test_svd(self):
+ denseMat = RowMatrix(self.sc.parallelize(self.denseData))
+ sparseMat = RowMatrix(self.sc.parallelize(self.sparseData))
+ m = 4
+ n = 3
+ for mat in [denseMat, sparseMat]:
+ for k in range(1, 4):
+ rm = mat.computeSVD(k, computeU=True)
+ self.assertEqual(rm.s.size, k)
+ self.assertEqual(rm.U.numRows(), m)
+ self.assertEqual(rm.U.numCols(), k)
+ self.assertEqual(rm.V.numRows, n)
+ self.assertEqual(rm.V.numCols, k)
+
+ # Test that U returned is None if computeU is set to False.
+ self.assertEqual(mat.computeSVD(1).U, None)
+
+            # Test that a rank-1 matrix returns only the singular values above
+            # the rCond threshold, even when more are requested.
+ rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1))))
+ self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1)
+
+ def test_pca(self):
+ expected_pcs = array([
+ [0.0, 1.0, 0.0],
+ [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0],
+ [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0]
+ ])
+ n = 3
+ denseMat = RowMatrix(self.sc.parallelize(self.denseData))
+ sparseMat = RowMatrix(self.sc.parallelize(self.sparseData))
+ for mat in [denseMat, sparseMat]:
+ for k in range(1, 4):
+ pcs = mat.computePrincipalComponents(k)
+ self.assertEqual(pcs.numRows, n)
+ self.assertEqual(pcs.numCols, k)
+
+                # It suffices to check the newly added (k-th) principal component for equality.
+ self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1])
+
+
+if __name__ == "__main__":
+ from pyspark.mllib.tests.test_feature import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py
new file mode 100644
index 0000000000000..d0ebd9bc3db79
--- /dev/null
+++ b/python/pyspark/mllib/tests/test_linalg.py
@@ -0,0 +1,633 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import array as pyarray
+import unittest
+
+from numpy import array, array_equal, zeros, arange, tile, ones, inf
+
+import pyspark.ml.linalg as newlinalg
+from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \
+ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
+
+_have_scipy = False
+try:
+ import scipy.sparse
+ _have_scipy = True
+except ImportError:
+ # No SciPy, but that's okay, we'll skip those tests
+ pass
+
+
+ser = make_serializer()
+
+
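+# Call squared_distance on whichever operand is an MLlib Vector, so plain lists
+# and arrays can be compared against Vectors in either order.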
+def _squared_distance(a, b):
+ if isinstance(a, Vector):
+ return a.squared_distance(b)
+ else:
+ return b.squared_distance(a)
+
+
+class VectorTests(MLlibTestCase):
+
+ def _test_serialize(self, v):
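+        # Round-trip the vector through Python pickling and the JVM SerDe,
+        # both individually and in a batch of 100, and check it survives intact.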
+ self.assertEqual(v, ser.loads(ser.dumps(v)))
+ jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v)))
+ nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec)))
+ self.assertEqual(v, nv)
+ vs = [v] * 100
+ jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs)))
+ nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs)))
+ self.assertEqual(vs, nvs)
+
+ def test_serialize(self):
+ self._test_serialize(DenseVector(range(10)))
+ self._test_serialize(DenseVector(array([1., 2., 3., 4.])))
+ self._test_serialize(DenseVector(pyarray.array('d', range(10))))
+ self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
+ self._test_serialize(SparseVector(3, {}))
+ self._test_serialize(DenseMatrix(2, 3, range(6)))
+ sm1 = SparseMatrix(
+ 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
+ self._test_serialize(sm1)
+
+ def test_dot(self):
+ sv = SparseVector(4, {1: 1, 3: 2})
+ dv = DenseVector(array([1., 2., 3., 4.]))
+ lst = DenseVector([1, 2, 3, 4])
+ mat = array([[1., 2., 3., 4.],
+ [1., 2., 3., 4.],
+ [1., 2., 3., 4.],
+ [1., 2., 3., 4.]])
+ arr = pyarray.array('d', [0, 1, 2, 3])
+ self.assertEqual(10.0, sv.dot(dv))
+ self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat)))
+ self.assertEqual(30.0, dv.dot(dv))
+ self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat)))
+ self.assertEqual(30.0, lst.dot(dv))
+ self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat)))
+ self.assertEqual(7.0, sv.dot(arr))
+
+ def test_squared_distance(self):
+ sv = SparseVector(4, {1: 1, 3: 2})
+ dv = DenseVector(array([1., 2., 3., 4.]))
+ lst = DenseVector([4, 3, 2, 1])
+ lst1 = [4, 3, 2, 1]
+ arr = pyarray.array('d', [0, 2, 1, 3])
+ narr = array([0, 2, 1, 3])
+ self.assertEqual(15.0, _squared_distance(sv, dv))
+ self.assertEqual(25.0, _squared_distance(sv, lst))
+ self.assertEqual(20.0, _squared_distance(dv, lst))
+ self.assertEqual(15.0, _squared_distance(dv, sv))
+ self.assertEqual(25.0, _squared_distance(lst, sv))
+ self.assertEqual(20.0, _squared_distance(lst, dv))
+ self.assertEqual(0.0, _squared_distance(sv, sv))
+ self.assertEqual(0.0, _squared_distance(dv, dv))
+ self.assertEqual(0.0, _squared_distance(lst, lst))
+ self.assertEqual(25.0, _squared_distance(sv, lst1))
+ self.assertEqual(3.0, _squared_distance(sv, arr))
+ self.assertEqual(3.0, _squared_distance(sv, narr))
+
+ def test_hash(self):
+ v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
+ v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v4 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+ self.assertEqual(hash(v1), hash(v2))
+ self.assertEqual(hash(v1), hash(v3))
+ self.assertEqual(hash(v2), hash(v3))
+ self.assertFalse(hash(v1) == hash(v4))
+ self.assertFalse(hash(v2) == hash(v4))
+
+ def test_eq(self):
+ v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
+ v3 = DenseVector([0.0, 1.0, 0.0, 5.5])
+ v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
+ v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
+ v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+ self.assertEqual(v1, v2)
+ self.assertEqual(v1, v3)
+ self.assertFalse(v2 == v4)
+ self.assertFalse(v1 == v5)
+ self.assertFalse(v1 == v6)
+
+ def test_equals(self):
+ indices = [1, 2, 4]
+ values = [1., 3., 2.]
+ self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.]))
+ self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.]))
+
+ def test_conversion(self):
+ # numpy arrays should be automatically upcast to float64
+ # tests for fix of [SPARK-5089]
+ v = array([1, 2, 3, 4], dtype='float64')
+ dv = DenseVector(v)
+ self.assertTrue(dv.array.dtype == 'float64')
+ v = array([1, 2, 3, 4], dtype='float32')
+ dv = DenseVector(v)
+ self.assertTrue(dv.array.dtype == 'float64')
+
+ def test_sparse_vector_indexing(self):
+ sv = SparseVector(5, {1: 1, 3: 2})
+ self.assertEqual(sv[0], 0.)
+ self.assertEqual(sv[3], 2.)
+ self.assertEqual(sv[1], 1.)
+ self.assertEqual(sv[2], 0.)
+ self.assertEqual(sv[4], 0.)
+ self.assertEqual(sv[-1], 0.)
+ self.assertEqual(sv[-2], 2.)
+ self.assertEqual(sv[-3], 0.)
+ self.assertEqual(sv[-5], 0.)
+ for ind in [5, -6]:
+ self.assertRaises(IndexError, sv.__getitem__, ind)
+ for ind in [7.8, '1']:
+ self.assertRaises(TypeError, sv.__getitem__, ind)
+
+ zeros = SparseVector(4, {})
+ self.assertEqual(zeros[0], 0.0)
+ self.assertEqual(zeros[3], 0.0)
+ for ind in [4, -5]:
+ self.assertRaises(IndexError, zeros.__getitem__, ind)
+
+ empty = SparseVector(0, {})
+ for ind in [-1, 0, 1]:
+ self.assertRaises(IndexError, empty.__getitem__, ind)
+
+ def test_sparse_vector_iteration(self):
+ self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0])
+ self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0])
+
+ def test_matrix_indexing(self):
+ mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
+ expected = [[0, 6], [1, 8], [4, 10]]
+ for i in range(3):
+ for j in range(2):
+ self.assertEqual(mat[i, j], expected[i][j])
+
+ for i, j in [(-1, 0), (4, 1), (3, 4)]:
+ self.assertRaises(IndexError, mat.__getitem__, (i, j))
+
+ def test_repr_dense_matrix(self):
+ mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10])
+ self.assertTrue(
+ repr(mat),
+ 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
+
+ mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True)
+ self.assertTrue(
+ repr(mat),
+ 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)')
+
+ mat = DenseMatrix(6, 3, zeros(18))
+ self.assertTrue(
+ repr(mat),
+ 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)')
+
+ def test_repr_sparse_matrix(self):
+ sm1t = SparseMatrix(
+ 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
+ isTransposed=True)
+ self.assertTrue(
+ repr(sm1t),
+ 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)')
+
+ indices = tile(arange(6), 3)
+ values = ones(18)
+ sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values)
+ self.assertTrue(
+ repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \
+ [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \
+ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \
+ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)")
+
+ self.assertTrue(
+ str(sm),
+ "6 X 3 CSCMatrix\n\
+ (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\
+ (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\
+ (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..")
+
+ sm = SparseMatrix(1, 18, zeros(19), [], [])
+ self.assertTrue(
+ repr(sm),
+ 'SparseMatrix(1, 18, \
+ [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)')
+
+ def test_sparse_matrix(self):
+ # Test sparse matrix creation.
+ sm1 = SparseMatrix(
+ 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
+ self.assertEqual(sm1.numRows, 3)
+ self.assertEqual(sm1.numCols, 4)
+ self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4])
+ self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2])
+ self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0])
+ self.assertTrue(
+ repr(sm1),
+ 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)')
+
+ # Test indexing
+ expected = [
+ [0, 0, 0, 0],
+ [1, 0, 4, 0],
+ [2, 0, 5, 0]]
+
+ for i in range(3):
+ for j in range(4):
+ self.assertEqual(expected[i][j], sm1[i, j])
+ self.assertTrue(array_equal(sm1.toArray(), expected))
+
+ for i, j in [(-1, 1), (4, 3), (3, 5)]:
+ self.assertRaises(IndexError, sm1.__getitem__, (i, j))
+
+ # Test conversion to dense and sparse.
+ smnew = sm1.toDense().toSparse()
+ self.assertEqual(sm1.numRows, smnew.numRows)
+ self.assertEqual(sm1.numCols, smnew.numCols)
+ self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs))
+ self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices))
+ self.assertTrue(array_equal(sm1.values, smnew.values))
+
+ sm1t = SparseMatrix(
+ 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0],
+ isTransposed=True)
+ self.assertEqual(sm1t.numRows, 3)
+ self.assertEqual(sm1t.numCols, 4)
+ self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5])
+ self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2])
+ self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0])
+
+ expected = [
+ [3, 2, 0, 0],
+ [0, 0, 4, 0],
+ [9, 0, 8, 0]]
+
+ for i in range(3):
+ for j in range(4):
+ self.assertEqual(expected[i][j], sm1t[i, j])
+ self.assertTrue(array_equal(sm1t.toArray(), expected))
+
+ def test_dense_matrix_is_transposed(self):
+ mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True)
+ mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9])
+ self.assertEqual(mat1, mat)
+
+ expected = [[0, 4], [1, 6], [3, 9]]
+ for i in range(3):
+ for j in range(2):
+ self.assertEqual(mat1[i, j], expected[i][j])
+ self.assertTrue(array_equal(mat1.toArray(), expected))
+
+ sm = mat1.toSparse()
+ self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2]))
+ self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
+ self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
+
+ def test_parse_vector(self):
+ a = DenseVector([])
+ self.assertEqual(str(a), '[]')
+ self.assertEqual(Vectors.parse(str(a)), a)
+ a = DenseVector([3, 4, 6, 7])
+ self.assertEqual(str(a), '[3.0,4.0,6.0,7.0]')
+ self.assertEqual(Vectors.parse(str(a)), a)
+ a = SparseVector(4, [], [])
+ self.assertEqual(str(a), '(4,[],[])')
+ self.assertEqual(SparseVector.parse(str(a)), a)
+ a = SparseVector(4, [0, 2], [3, 4])
+ self.assertEqual(str(a), '(4,[0,2],[3.0,4.0])')
+ self.assertEqual(Vectors.parse(str(a)), a)
+ a = SparseVector(10, [0, 1], [4, 5])
+ self.assertEqual(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
+
+ def test_norms(self):
+ a = DenseVector([0, 2, 3, -1])
+ self.assertAlmostEqual(a.norm(2), 3.742, 3)
+        self.assertEqual(a.norm(1), 6)
+        self.assertEqual(a.norm(inf), 3)
+ a = SparseVector(4, [0, 2], [3, -4])
+ self.assertAlmostEqual(a.norm(2), 5)
+        self.assertEqual(a.norm(1), 7)
+        self.assertEqual(a.norm(inf), 4)
+
+ tmp = SparseVector(4, [0, 2], [3, 0])
+ self.assertEqual(tmp.numNonzeros(), 1)
+
+ def test_ml_mllib_vector_conversion(self):
+ # to ml
+ # dense
+ mllibDV = Vectors.dense([1, 2, 3])
+ mlDV1 = newlinalg.Vectors.dense([1, 2, 3])
+ mlDV2 = mllibDV.asML()
+ self.assertEqual(mlDV2, mlDV1)
+ # sparse
+ mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5})
+ mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5})
+ mlSV2 = mllibSV.asML()
+ self.assertEqual(mlSV2, mlSV1)
+ # from ml
+ # dense
+ mllibDV1 = Vectors.dense([1, 2, 3])
+ mlDV = newlinalg.Vectors.dense([1, 2, 3])
+ mllibDV2 = Vectors.fromML(mlDV)
+ self.assertEqual(mllibDV1, mllibDV2)
+ # sparse
+ mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5})
+ mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5})
+ mllibSV2 = Vectors.fromML(mlSV)
+ self.assertEqual(mllibSV1, mllibSV2)
+
+ def test_ml_mllib_matrix_conversion(self):
+ # to ml
+ # dense
+ mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3])
+ mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3])
+ mlDM2 = mllibDM.asML()
+ self.assertEqual(mlDM2, mlDM1)
+ # transposed
+ mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True)
+ mlDMt1 = newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True)
+ mlDMt2 = mllibDMt.asML()
+ self.assertEqual(mlDMt2, mlDMt1)
+ # sparse
+ mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+ mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+ mlSM2 = mllibSM.asML()
+ self.assertEqual(mlSM2, mlSM1)
+ # transposed
+ mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+ mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+ mlSMt2 = mllibSMt.asML()
+ self.assertEqual(mlSMt2, mlSMt1)
+ # from ml
+ # dense
+ mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4])
+ mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4])
+ mllibDM2 = Matrices.fromML(mlDM)
+ self.assertEqual(mllibDM1, mllibDM2)
+ # transposed
+ mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], True)
+ mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True)
+ mllibDMt2 = Matrices.fromML(mlDMt)
+ self.assertEqual(mllibDMt1, mllibDMt2)
+ # sparse
+ mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+ mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+ mllibSM2 = Matrices.fromML(mlSM)
+ self.assertEqual(mllibSM1, mllibSM2)
+ # transposed
+ mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+ mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+ mllibSMt2 = Matrices.fromML(mlSMt)
+ self.assertEqual(mllibSMt1, mllibSMt2)
+
+
+class VectorUDTTests(MLlibTestCase):
+
+ dv0 = DenseVector([])
+ dv1 = DenseVector([1.0, 2.0])
+ sv0 = SparseVector(2, [], [])
+ sv1 = SparseVector(2, [1], [2.0])
+ udt = VectorUDT()
+
+ def test_json_schema(self):
+ self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+ def test_serialization(self):
+ for v in [self.dv0, self.dv1, self.sv0, self.sv1]:
+ self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
+
+ def test_infer_schema(self):
+ rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
+ df = rdd.toDF()
+ schema = df.schema
+ field = [f for f in schema.fields if f.name == "features"][0]
+ self.assertEqual(field.dataType, self.udt)
+ vectors = df.rdd.map(lambda p: p.features).collect()
+ self.assertEqual(len(vectors), 2)
+ for v in vectors:
+ if isinstance(v, SparseVector):
+ self.assertEqual(v, self.sv1)
+ elif isinstance(v, DenseVector):
+ self.assertEqual(v, self.dv1)
+ else:
+ raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
+
+
+class MatrixUDTTests(MLlibTestCase):
+
+ dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
+ dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
+ sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
+ sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
+ udt = MatrixUDT()
+
+ def test_json_schema(self):
+ self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+ def test_serialization(self):
+ for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
+ self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
+
+ def test_infer_schema(self):
+ rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
+ df = rdd.toDF()
+ schema = df.schema
+ self.assertTrue(schema.fields[1].dataType, self.udt)
+ matrices = df.rdd.map(lambda x: x._2).collect()
+ self.assertEqual(len(matrices), 2)
+ for m in matrices:
+ if isinstance(m, DenseMatrix):
+ self.assertTrue(m, self.dm1)
+ elif isinstance(m, SparseMatrix):
+ self.assertTrue(m, self.sm1)
+ else:
+ raise ValueError("Expected a matrix but got type %r" % type(m))
+
+
+@unittest.skipIf(not _have_scipy, "SciPy not installed")
+class SciPyTests(MLlibTestCase):
+
+ """
+ Test both vector operations and MLlib algorithms with SciPy sparse matrices,
+ if SciPy is available.
+ """
+
+ def test_serialize(self):
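+        # Every supported SciPy sparse format should convert to the same MLlib SparseVector.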
+ from scipy.sparse import lil_matrix
+ lil = lil_matrix((4, 1))
+ lil[1, 0] = 1
+ lil[3, 0] = 2
+ sv = SparseVector(4, {1: 1, 3: 2})
+ self.assertEqual(sv, _convert_to_vector(lil))
+ self.assertEqual(sv, _convert_to_vector(lil.tocsc()))
+ self.assertEqual(sv, _convert_to_vector(lil.tocoo()))
+ self.assertEqual(sv, _convert_to_vector(lil.tocsr()))
+ self.assertEqual(sv, _convert_to_vector(lil.todok()))
+
+ def serialize(l):
+ return ser.loads(ser.dumps(_convert_to_vector(l)))
+ self.assertEqual(sv, serialize(lil))
+ self.assertEqual(sv, serialize(lil.tocsc()))
+ self.assertEqual(sv, serialize(lil.tocsr()))
+ self.assertEqual(sv, serialize(lil.todok()))
+
+ def test_convert_to_vector(self):
+ from scipy.sparse import csc_matrix
+ # Create a CSC matrix with non-sorted indices
+ indptr = array([0, 2])
+ indices = array([3, 1])
+ data = array([2.0, 1.0])
+ csc = csc_matrix((data, indices, indptr))
+ self.assertFalse(csc.has_sorted_indices)
+ sv = SparseVector(4, {1: 1, 3: 2})
+ self.assertEqual(sv, _convert_to_vector(csc))
+
+ def test_dot(self):
+ from scipy.sparse import lil_matrix
+ lil = lil_matrix((4, 1))
+ lil[1, 0] = 1
+ lil[3, 0] = 2
+ dv = DenseVector(array([1., 2., 3., 4.]))
+ self.assertEqual(10.0, dv.dot(lil))
+
+ def test_squared_distance(self):
+ from scipy.sparse import lil_matrix
+ lil = lil_matrix((4, 1))
+ lil[1, 0] = 3
+ lil[3, 0] = 2
+ dv = DenseVector(array([1., 2., 3., 4.]))
+ sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4})
+ self.assertEqual(15.0, dv.squared_distance(lil))
+ self.assertEqual(15.0, sv.squared_distance(lil))
+
+ def scipy_matrix(self, size, values):
+ """Create a column SciPy matrix from a dictionary of values"""
+ from scipy.sparse import lil_matrix
+ lil = lil_matrix((size, 1))
+ for key, value in values.items():
+ lil[key, 0] = value
+ return lil
+
+ def test_clustering(self):
+ from pyspark.mllib.clustering import KMeans
+ data = [
+ self.scipy_matrix(3, {1: 1.0}),
+ self.scipy_matrix(3, {1: 1.1}),
+ self.scipy_matrix(3, {2: 1.0}),
+ self.scipy_matrix(3, {2: 1.1})
+ ]
+ clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||")
+ self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1]))
+ self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3]))
+
+ def test_classification(self):
+ from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
+ from pyspark.mllib.tree import DecisionTree
+ data = [
+ LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})),
+ LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})),
+ LabeledPoint(0.0, self.scipy_matrix(2, {0: 2.0})),
+ LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0}))
+ ]
+ rdd = self.sc.parallelize(data)
+ features = [p.features for p in data]
+
+ lr_model = LogisticRegressionWithSGD.train(rdd)
+ self.assertTrue(lr_model.predict(features[0]) <= 0)
+ self.assertTrue(lr_model.predict(features[1]) > 0)
+ self.assertTrue(lr_model.predict(features[2]) <= 0)
+ self.assertTrue(lr_model.predict(features[3]) > 0)
+
+ svm_model = SVMWithSGD.train(rdd)
+ self.assertTrue(svm_model.predict(features[0]) <= 0)
+ self.assertTrue(svm_model.predict(features[1]) > 0)
+ self.assertTrue(svm_model.predict(features[2]) <= 0)
+ self.assertTrue(svm_model.predict(features[3]) > 0)
+
+ nb_model = NaiveBayes.train(rdd)
+ self.assertTrue(nb_model.predict(features[0]) <= 0)
+ self.assertTrue(nb_model.predict(features[1]) > 0)
+ self.assertTrue(nb_model.predict(features[2]) <= 0)
+ self.assertTrue(nb_model.predict(features[3]) > 0)
+
+ categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
+ dt_model = DecisionTree.trainClassifier(rdd, numClasses=2,
+ categoricalFeaturesInfo=categoricalFeaturesInfo)
+ self.assertTrue(dt_model.predict(features[0]) <= 0)
+ self.assertTrue(dt_model.predict(features[1]) > 0)
+ self.assertTrue(dt_model.predict(features[2]) <= 0)
+ self.assertTrue(dt_model.predict(features[3]) > 0)
+
+ def test_regression(self):
+ from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
+ RidgeRegressionWithSGD
+ from pyspark.mllib.tree import DecisionTree
+ data = [
+ LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})),
+ LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})),
+ LabeledPoint(-1.0, self.scipy_matrix(2, {1: -2.0})),
+ LabeledPoint(1.0, self.scipy_matrix(2, {1: 2.0}))
+ ]
+ rdd = self.sc.parallelize(data)
+ features = [p.features for p in data]
+
+ lr_model = LinearRegressionWithSGD.train(rdd)
+ self.assertTrue(lr_model.predict(features[0]) <= 0)
+ self.assertTrue(lr_model.predict(features[1]) > 0)
+ self.assertTrue(lr_model.predict(features[2]) <= 0)
+ self.assertTrue(lr_model.predict(features[3]) > 0)
+
+ lasso_model = LassoWithSGD.train(rdd)
+ self.assertTrue(lasso_model.predict(features[0]) <= 0)
+ self.assertTrue(lasso_model.predict(features[1]) > 0)
+ self.assertTrue(lasso_model.predict(features[2]) <= 0)
+ self.assertTrue(lasso_model.predict(features[3]) > 0)
+
+ rr_model = RidgeRegressionWithSGD.train(rdd)
+ self.assertTrue(rr_model.predict(features[0]) <= 0)
+ self.assertTrue(rr_model.predict(features[1]) > 0)
+ self.assertTrue(rr_model.predict(features[2]) <= 0)
+ self.assertTrue(rr_model.predict(features[3]) > 0)
+
+ categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
+ dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+ self.assertTrue(dt_model.predict(features[0]) <= 0)
+ self.assertTrue(dt_model.predict(features[1]) > 0)
+ self.assertTrue(dt_model.predict(features[2]) <= 0)
+ self.assertTrue(dt_model.predict(features[3]) > 0)
+
+
+if __name__ == "__main__":
+ from pyspark.mllib.tests.test_linalg import *
+ if not _have_scipy:
+ print("NOTE: Skipping SciPy tests as it does not seem to be installed")
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
+ if not _have_scipy:
+ print("NOTE: SciPy tests were skipped as it does not seem to be installed")
diff --git a/python/pyspark/mllib/tests/test_stat.py b/python/pyspark/mllib/tests/test_stat.py
new file mode 100644
index 0000000000000..f23ae291d317a
--- /dev/null
+++ b/python/pyspark/mllib/tests/test_stat.py
@@ -0,0 +1,188 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import array as pyarray
+import unittest
+
+from numpy import array
+
+from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector, \
+ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
+from pyspark.mllib.random import RandomRDDs
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.stat import Statistics
+from pyspark.sql.utils import IllegalArgumentException
+from pyspark.testing.mllibutils import MLlibTestCase
+
+
+class StatTests(MLlibTestCase):
+ # SPARK-4023
+ def test_col_with_different_rdds(self):
+ # numpy
+ data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(1000, summary.count())
+ # array
+ data = self.sc.parallelize([range(10)] * 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(10, summary.count())
+ # array
+ data = self.sc.parallelize([pyarray.array("d", range(10))] * 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(10, summary.count())
+
+ def test_col_norms(self):
+ data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(10, len(summary.normL1()))
+ self.assertEqual(10, len(summary.normL2()))
+
+ data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x))
+ summary2 = Statistics.colStats(data2)
+ self.assertEqual(array([45.0]), summary2.normL1())
+ import math
+ expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10))))
+ self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14)
+
+
+class ChiSqTestTests(MLlibTestCase):
+ def test_goodness_of_fit(self):
+ from numpy import inf
+
+ observed = Vectors.dense([4, 6, 5])
+ pearson = Statistics.chiSqTest(observed)
+
+ # Validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))`
+ self.assertEqual(pearson.statistic, 0.4)
+ self.assertEqual(pearson.degreesOfFreedom, 2)
+ self.assertAlmostEqual(pearson.pValue, 0.8187, 4)
+
+ # Different expected and observed sum
+ observed1 = Vectors.dense([21, 38, 43, 80])
+ expected1 = Vectors.dense([3, 5, 7, 20])
+ pearson1 = Statistics.chiSqTest(observed1, expected1)
+
+ # Results validated against the R command
+ # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))`
+ self.assertAlmostEqual(pearson1.statistic, 14.1429, 4)
+ self.assertEqual(pearson1.degreesOfFreedom, 3)
+ self.assertAlmostEqual(pearson1.pValue, 0.002717, 4)
+
+ # Vectors with different sizes
+ observed3 = Vectors.dense([1.0, 2.0, 3.0])
+ expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0])
+ self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3)
+
+ # Negative counts in observed
+ neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0])
+ self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1)
+
+ # Count = 0.0 in expected but not observed
+ zero_expected = Vectors.dense([1.0, 0.0, 3.0])
+ pearson_inf = Statistics.chiSqTest(observed, zero_expected)
+ self.assertEqual(pearson_inf.statistic, inf)
+ self.assertEqual(pearson_inf.degreesOfFreedom, 2)
+ self.assertEqual(pearson_inf.pValue, 0.0)
+
+ # 0.0 in expected and observed simultaneously
+ zero_observed = Vectors.dense([2.0, 0.0, 1.0])
+ self.assertRaises(
+ IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected)
+
+ def test_matrix_independence(self):
+ data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0]
+ chi = Statistics.chiSqTest(Matrices.dense(3, 4, data))
+
+ # Results validated against R command
+ # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))`
+ self.assertAlmostEqual(chi.statistic, 21.9958, 4)
+ self.assertEqual(chi.degreesOfFreedom, 6)
+ self.assertAlmostEqual(chi.pValue, 0.001213, 4)
+
+ # Negative counts
+ neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0])
+ self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts)
+
+ # Row sum = 0.0
+ row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0])
+ self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero)
+
+ # Column sum = 0.0
+ col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0])
+ self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero)
+
+ def test_chi_sq_pearson(self):
+ data = [
+ LabeledPoint(0.0, Vectors.dense([0.5, 10.0])),
+ LabeledPoint(0.0, Vectors.dense([1.5, 20.0])),
+ LabeledPoint(1.0, Vectors.dense([1.5, 30.0])),
+ LabeledPoint(0.0, Vectors.dense([3.5, 30.0])),
+ LabeledPoint(0.0, Vectors.dense([3.5, 40.0])),
+ LabeledPoint(1.0, Vectors.dense([3.5, 40.0]))
+ ]
+
+ for numParts in [2, 4, 6, 8]:
+ chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts))
+ feature1 = chi[0]
+ self.assertEqual(feature1.statistic, 0.75)
+ self.assertEqual(feature1.degreesOfFreedom, 2)
+ self.assertAlmostEqual(feature1.pValue, 0.6873, 4)
+
+ feature2 = chi[1]
+ self.assertEqual(feature2.statistic, 1.5)
+ self.assertEqual(feature2.degreesOfFreedom, 3)
+ self.assertAlmostEqual(feature2.pValue, 0.6823, 4)
+
+ def test_right_number_of_results(self):
+ num_cols = 1001
+ sparse_data = [
+ LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])),
+ LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)]))
+ ]
+ chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data))
+ self.assertEqual(len(chi), num_cols)
+ self.assertIsNotNone(chi[1000])
+
+
+class KolmogorovSmirnovTest(MLlibTestCase):
+
+ def test_R_implementation_equivalence(self):
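+        # The expected statistic and p-value below are intended to match R's
+        # ks.test against a standard normal distribution.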
+ data = self.sc.parallelize([
+ 1.1626852897838, -0.585924465893051, 1.78546500331661, -1.33259371048501,
+ -0.446566766553219, 0.569606122374976, -2.88971761441412, -0.869018343326555,
+ -0.461702683149641, -0.555540910137444, -0.0201353678515895, -0.150382224136063,
+ -0.628126755843964, 1.32322085193283, -1.52135057001199, -0.437427868856691,
+ 0.970577579543399, 0.0282226444247749, -0.0857821886527593, 0.389214404984942
+ ])
+ model = Statistics.kolmogorovSmirnovTest(data, "norm")
+ self.assertAlmostEqual(model.statistic, 0.189, 3)
+ self.assertAlmostEqual(model.pValue, 0.422, 3)
+
+ model = Statistics.kolmogorovSmirnovTest(data, "norm", 0, 1)
+ self.assertAlmostEqual(model.statistic, 0.189, 3)
+ self.assertAlmostEqual(model.pValue, 0.422, 3)
+
+
+if __name__ == "__main__":
+ from pyspark.mllib.tests.test_stat import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py
new file mode 100644
index 0000000000000..4bc8904acd31c
--- /dev/null
+++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py
@@ -0,0 +1,514 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from time import time, sleep
+import unittest
+
+from numpy import array, random, exp, dot, all, mean, abs
+from numpy import sum as array_sum
+
+from pyspark import SparkContext
+from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
+from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
+from pyspark.mllib.util import LinearDataGenerator
+from pyspark.streaming import StreamingContext
+
+
+class MLLibStreamingTestCase(unittest.TestCase):
+ def setUp(self):
+ self.sc = SparkContext('local[4]', "MLlib tests")
+ self.ssc = StreamingContext(self.sc, 1.0)
+
+ def tearDown(self):
+ self.ssc.stop(False)
+ self.sc.stop()
+
+ @staticmethod
+ def _eventually(condition, timeout=30.0, catch_assertions=False):
+ """
+ Wait a given amount of time for a condition to pass, else fail with an error.
+ This is a helper utility for streaming ML tests.
+ :param condition: Function that checks for termination conditions.
+ condition() can return:
+ - True: Conditions met. Return without error.
+ - other value: Conditions not met yet. Continue. Upon timeout,
+ include last such value in error message.
+ Note that this method may be called at any time during
+ streaming execution (e.g., even before any results
+ have been created).
+ :param timeout: Number of seconds to wait. Default 30 seconds.
+ :param catch_assertions: If False (default), do not catch AssertionErrors.
+ If True, catch AssertionErrors; continue, but save
+ error to throw upon timeout.
+ """
+ start_time = time()
+ lastValue = None
+ while time() - start_time < timeout:
+ if catch_assertions:
+ try:
+ lastValue = condition()
+ except AssertionError as e:
+ lastValue = e
+ else:
+ lastValue = condition()
+ if lastValue is True:
+ return
+ sleep(0.01)
+ if isinstance(lastValue, AssertionError):
+ raise lastValue
+ else:
+ raise AssertionError(
+ "Test failed due to timeout after %g sec, with last condition returning: %s"
+ % (timeout, lastValue))
+
+
+class StreamingKMeansTest(MLLibStreamingTestCase):
+ def test_model_params(self):
+ """Test that the model params are set correctly"""
+ stkm = StreamingKMeans()
+ stkm.setK(5).setDecayFactor(0.0)
+ self.assertEqual(stkm._k, 5)
+ self.assertEqual(stkm._decayFactor, 0.0)
+
+ # Model not set yet.
+ self.assertIsNone(stkm.latestModel())
+ self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0])
+
+ stkm.setInitialCenters(
+ centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0])
+ self.assertEqual(
+ stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]])
+ self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0])
+
+ def test_accuracy_for_single_center(self):
+ """Test that parameters obtained are correct for a single center."""
+ centers, batches = self.streamingKMeansDataGenerator(
+ batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0)
+ stkm = StreamingKMeans(1)
+ stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.])
+ input_stream = self.ssc.queueStream(
+ [self.sc.parallelize(batch, 1) for batch in batches])
+ stkm.trainOn(input_stream)
+
+ self.ssc.start()
+
+ def condition():
+ self.assertEqual(stkm.latestModel().clusterWeights, [25.0])
+ return True
+ self._eventually(condition, catch_assertions=True)
+
+ realCenters = array_sum(array(centers), axis=0)
+ for i in range(5):
+ modelCenters = stkm.latestModel().centers[0][i]
+ self.assertAlmostEqual(centers[0][i], modelCenters, 1)
+ self.assertAlmostEqual(realCenters[i], modelCenters, 1)
+
+ def streamingKMeansDataGenerator(self, batches, numPoints,
+ k, d, r, seed, centers=None):
+ rng = random.RandomState(seed)
+
+ # Generate centers.
+ centers = [rng.randn(d) for i in range(k)]
+
+ return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d))
+ for j in range(numPoints)]
+ for i in range(batches)]
+
+ def test_trainOn_model(self):
+ """Test the model on toy data with four clusters."""
+ stkm = StreamingKMeans()
+ initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]]
+ stkm.setInitialCenters(
+ centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0])
+
+ # Create a toy dataset by setting a tiny offset for each point.
+ offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]]
+ batches = []
+ for offset in offsets:
+ batches.append([[offset[0] + center[0], offset[1] + center[1]]
+ for center in initCenters])
+
+ batches = [self.sc.parallelize(batch, 1) for batch in batches]
+ input_stream = self.ssc.queueStream(batches)
+ stkm.trainOn(input_stream)
+ self.ssc.start()
+
+ # Give enough time to train the model.
+ def condition():
+ finalModel = stkm.latestModel()
+ self.assertTrue(all(finalModel.centers == array(initCenters)))
+ self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0])
+ return True
+ self._eventually(condition, catch_assertions=True)
+
+ def test_predictOn_model(self):
+ """Test that the model predicts correctly on toy data."""
+ stkm = StreamingKMeans()
+ stkm._model = StreamingKMeansModel(
+ clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]],
+ clusterWeights=[1.0, 1.0, 1.0, 1.0])
+
+ predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]]
+ predict_data = [self.sc.parallelize(batch, 1) for batch in predict_data]
+ predict_stream = self.ssc.queueStream(predict_data)
+ predict_val = stkm.predictOn(predict_stream)
+
+ result = []
+
+ def update(rdd):
+ rdd_collect = rdd.collect()
+ if rdd_collect:
+ result.append(rdd_collect)
+
+ predict_val.foreachRDD(update)
+ self.ssc.start()
+
+ def condition():
+ self.assertEqual(result, [[0], [1], [2], [3]])
+ return True
+
+ self._eventually(condition, catch_assertions=True)
+
+ @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark")
+ def test_trainOn_predictOn(self):
+ """Test that prediction happens on the updated model."""
+ stkm = StreamingKMeans(decayFactor=0.0, k=2)
+ stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0])
+
+        # Since the decay factor is set to zero, once the first batch has been
+        # processed the cluster centers are updated to [-0.5, 0.7], which causes
+        # 0.2 and 0.3 to be classified as 1 even though a classification based
+        # on the initial model would have been 0, proving that the model has
+        # been updated.
+ batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]]
+ batches = [self.sc.parallelize(batch) for batch in batches]
+ input_stream = self.ssc.queueStream(batches)
+ predict_results = []
+
+ def collect(rdd):
+ rdd_collect = rdd.collect()
+ if rdd_collect:
+ predict_results.append(rdd_collect)
+
+ stkm.trainOn(input_stream)
+ predict_stream = stkm.predictOn(input_stream)
+ predict_stream.foreachRDD(collect)
+
+ self.ssc.start()
+
+ def condition():
+ self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
+ return True
+
+ self._eventually(condition, catch_assertions=True)
+
+
+class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
+
+ @staticmethod
+ def generateLogisticInput(offset, scale, nPoints, seed):
+ """
+        Generate 1 / (1 + exp(-(x * scale + offset)))
+
+        where x is randomly distributed and the threshold and labels for each
+        sample in x are obtained from a random uniform distribution.
+ """
+ rng = random.RandomState(seed)
+ x = rng.randn(nPoints)
+ sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset)))
+ y_p = rng.rand(nPoints)
+ cut_off = y_p <= sigmoid
+ y_p[cut_off] = 1.0
+ y_p[~cut_off] = 0.0
+ return [
+ LabeledPoint(y_p[i], Vectors.dense([x[i]]))
+ for i in range(nPoints)]
+
+ def test_parameter_accuracy(self):
+ """
+ Test that the final value of weights is close to the desired value.
+ """
+ input_batches = [
+ self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
+ for i in range(20)]
+ input_stream = self.ssc.queueStream(input_batches)
+
+ slr = StreamingLogisticRegressionWithSGD(
+ stepSize=0.2, numIterations=25)
+ slr.setInitialWeights([0.0])
+ slr.trainOn(input_stream)
+
+ self.ssc.start()
+
+ def condition():
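+            # The learned weight should settle to within roughly 10% of the true
+            # value 1.5 (assertAlmostEqual at one decimal place tolerates ~0.05
+            # on the relative error).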
+ rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5
+ self.assertAlmostEqual(rel, 0.1, 1)
+ return True
+
+ self._eventually(condition, catch_assertions=True)
+
+ def test_convergence(self):
+ """
+ Test that weights converge to the required value on toy data.
+ """
+ input_batches = [
+ self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
+ for i in range(20)]
+ input_stream = self.ssc.queueStream(input_batches)
+ models = []
+
+ slr = StreamingLogisticRegressionWithSGD(
+ stepSize=0.2, numIterations=25)
+ slr.setInitialWeights([0.0])
+ slr.trainOn(input_stream)
+ input_stream.foreachRDD(
+ lambda x: models.append(slr.latestModel().weights[0]))
+
+ self.ssc.start()
+
+ def condition():
+ self.assertEqual(len(models), len(input_batches))
+ return True
+
+ # We want all batches to finish for this test.
+ self._eventually(condition, 60.0, catch_assertions=True)
+
+ t_models = array(models)
+ diff = t_models[1:] - t_models[:-1]
+ # Test that weights improve with a small tolerance
+ self.assertTrue(all(diff >= -0.1))
+ self.assertTrue(array_sum(diff > 0) > 1)
+
+ @staticmethod
+ def calculate_accuracy_error(true, predicted):
+ return sum(abs(array(true) - array(predicted))) / len(true)
+
+ def test_predictions(self):
+ """Test predicted values on a toy model."""
+ input_batches = []
+ for i in range(20):
+ batch = self.sc.parallelize(
+ self.generateLogisticInput(0, 1.5, 100, 42 + i))
+ input_batches.append(batch.map(lambda x: (x.label, x.features)))
+ input_stream = self.ssc.queueStream(input_batches)
+
+ slr = StreamingLogisticRegressionWithSGD(
+ stepSize=0.2, numIterations=25)
+ slr.setInitialWeights([1.5])
+ predict_stream = slr.predictOnValues(input_stream)
+ true_predicted = []
+ predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect()))
+ self.ssc.start()
+
+ def condition():
+ self.assertEqual(len(true_predicted), len(input_batches))
+ return True
+
+ self._eventually(condition, catch_assertions=True)
+
+ # Test that the accuracy error is no more than 0.4 on each batch.
+ for batch in true_predicted:
+ true, predicted = zip(*batch)
+ self.assertTrue(
+ self.calculate_accuracy_error(true, predicted) < 0.4)
+
+ def test_training_and_prediction(self):
+        """Test that the model improves on toy data as the number of batches increases."""
+ input_batches = [
+ self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i))
+ for i in range(20)]
+ predict_batches = [
+ b.map(lambda lp: (lp.label, lp.features)) for b in input_batches]
+
+ slr = StreamingLogisticRegressionWithSGD(
+ stepSize=0.01, numIterations=25)
+ slr.setInitialWeights([-0.1])
+ errors = []
+
+ def collect_errors(rdd):
+ true, predicted = zip(*rdd.collect())
+ errors.append(self.calculate_accuracy_error(true, predicted))
+
+ true_predicted = []
+ input_stream = self.ssc.queueStream(input_batches)
+ predict_stream = self.ssc.queueStream(predict_batches)
+ slr.trainOn(input_stream)
+ ps = slr.predictOnValues(predict_stream)
+ ps.foreachRDD(lambda x: collect_errors(x))
+
+ self.ssc.start()
+
+ def condition():
+ # Test that the improvement in error is > 0.3
+ if len(errors) == len(predict_batches):
+ self.assertGreater(errors[1] - errors[-1], 0.3)
+ if len(errors) >= 3 and errors[1] - errors[-1] > 0.3:
+ return True
+ return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
+
+ self._eventually(condition)
+
+
+class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
+
+ def assertArrayAlmostEqual(self, array1, array2, dec):
+        for i, j in zip(array1, array2):
+ self.assertAlmostEqual(i, j, dec)
+
+ def test_parameter_accuracy(self):
+        """Test that coefficients are estimated accurately by fitting on toy data."""
+
+ # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients
+ # (10, 10)
+ slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
+ slr.setInitialWeights([0.0, 0.0])
+ xMean = [0.0, 0.0]
+ xVariance = [1.0 / 3.0, 1.0 / 3.0]
+
+ # Create ten batches with 100 sample points in each.
+ batches = []
+ for i in range(10):
+ batch = LinearDataGenerator.generateLinearInput(
+ 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1)
+ batches.append(self.sc.parallelize(batch))
+
+ input_stream = self.ssc.queueStream(batches)
+ slr.trainOn(input_stream)
+ self.ssc.start()
+
+ def condition():
+ self.assertArrayAlmostEqual(
+ slr.latestModel().weights.array, [10., 10.], 1)
+ self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1)
+ return True
+
+ self._eventually(condition, catch_assertions=True)
+
+ def test_parameter_convergence(self):
+ """Test that the model parameters improve with streaming data."""
+ slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
+ slr.setInitialWeights([0.0])
+
+ # Create ten batches with 100 sample points in each.
+ batches = []
+ for i in range(10):
+ batch = LinearDataGenerator.generateLinearInput(
+ 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
+ batches.append(self.sc.parallelize(batch))
+
+ model_weights = []
+ input_stream = self.ssc.queueStream(batches)
+ input_stream.foreachRDD(
+ lambda x: model_weights.append(slr.latestModel().weights[0]))
+ slr.trainOn(input_stream)
+ self.ssc.start()
+
+ def condition():
+ self.assertEqual(len(model_weights), len(batches))
+ return True
+
+ # We want all batches to finish for this test.
+ self._eventually(condition, catch_assertions=True)
+
+ w = array(model_weights)
+ diff = w[1:] - w[:-1]
+ self.assertTrue(all(diff >= -0.1))
+
+ def test_prediction(self):
+ """Test prediction on a model with weights already set."""
+        # Create a model with initial weights equal to the true coefficients.
+ slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
+ slr.setInitialWeights([10.0, 10.0])
+
+ # Create ten batches with 100 sample points in each.
+ batches = []
+ for i in range(10):
+ batch = LinearDataGenerator.generateLinearInput(
+ 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0],
+ 100, 42 + i, 0.1)
+ batches.append(
+ self.sc.parallelize(batch).map(lambda lp: (lp.label, lp.features)))
+
+ input_stream = self.ssc.queueStream(batches)
+ output_stream = slr.predictOnValues(input_stream)
+ samples = []
+ output_stream.foreachRDD(lambda x: samples.append(x.collect()))
+
+ self.ssc.start()
+
+ def condition():
+ self.assertEqual(len(samples), len(batches))
+ return True
+
+ # We want all batches to finish for this test.
+ self._eventually(condition, catch_assertions=True)
+
+ # Test that mean absolute error on each batch is less than 0.1
+ for batch in samples:
+ true, predicted = zip(*batch)
+ self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1)
+
+ def test_train_prediction(self):
+ """Test that error on test data improves as model is trained."""
+ slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25)
+ slr.setInitialWeights([0.0])
+
+ # Create ten batches with 100 sample points in each.
+ batches = []
+ for i in range(10):
+ batch = LinearDataGenerator.generateLinearInput(
+ 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1)
+ batches.append(self.sc.parallelize(batch))
+
+ predict_batches = [
+ b.map(lambda lp: (lp.label, lp.features)) for b in batches]
+ errors = []
+
+ def func(rdd):
+ true, predicted = zip(*rdd.collect())
+ errors.append(mean(abs(true) - abs(predicted)))
+
+ input_stream = self.ssc.queueStream(batches)
+ output_stream = self.ssc.queueStream(predict_batches)
+ slr.trainOn(input_stream)
+ output_stream = slr.predictOnValues(output_stream)
+ output_stream.foreachRDD(func)
+ self.ssc.start()
+
+ def condition():
+ if len(errors) == len(predict_batches):
+ self.assertGreater(errors[1] - errors[-1], 2)
+ if len(errors) >= 3 and errors[1] - errors[-1] > 2:
+ return True
+ return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
+
+ self._eventually(condition)
+
+
+if __name__ == "__main__":
+ from pyspark.mllib.tests.test_streaming_algorithms import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
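A note on the pattern used by the streaming tests above: each test hands a small condition() callable to self._eventually(...), which re-runs it until it returns True or a timeout expires. The helper itself is defined elsewhere in the test module and is not part of this hunk; the sketch below only illustrates the polling idea under that assumption (the timeout value and error handling are illustrative, not the actual implementation).

import time

def eventually(condition, timeout=30.0, catch_assertions=False):
    # Re-invoke condition() until it returns True; any other return value
    # (e.g. a diagnostic string) is remembered and reported on timeout.
    start = time.time()
    last_result = None
    while time.time() - start < timeout:
        try:
            last_result = condition()
            if last_result is True:
                return
        except AssertionError as e:
            if not catch_assertions:
                raise
            last_result = e
        time.sleep(0.01)
    raise AssertionError(
        "Condition not met within %g seconds: %r" % (timeout, last_result))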
diff --git a/python/pyspark/mllib/tests/test_util.py b/python/pyspark/mllib/tests/test_util.py
new file mode 100644
index 0000000000000..e95716278f122
--- /dev/null
+++ b/python/pyspark/mllib/tests/test_util.py
@@ -0,0 +1,104 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+
+from pyspark.mllib.common import _to_java_object_rdd
+from pyspark.mllib.util import LinearDataGenerator
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.linalg import SparseVector, DenseVector, Vectors
+from pyspark.mllib.random import RandomRDDs
+from pyspark.testing.mllibutils import MLlibTestCase
+
+
+class MLUtilsTests(MLlibTestCase):
+ def test_append_bias(self):
+ data = [2.0, 2.0, 2.0]
+ ret = MLUtils.appendBias(data)
+ self.assertEqual(ret[3], 1.0)
+ self.assertEqual(type(ret), DenseVector)
+
+ def test_append_bias_with_vector(self):
+ data = Vectors.dense([2.0, 2.0, 2.0])
+ ret = MLUtils.appendBias(data)
+ self.assertEqual(ret[3], 1.0)
+ self.assertEqual(type(ret), DenseVector)
+
+ def test_append_bias_with_sp_vector(self):
+ data = Vectors.sparse(3, {0: 2.0, 2: 2.0})
+ expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0})
+ # Returned value must be SparseVector
+ ret = MLUtils.appendBias(data)
+ self.assertEqual(ret, expected)
+ self.assertEqual(type(ret), SparseVector)
+
+ def test_load_vectors(self):
+ import shutil
+ data = [
+ [1.0, 2.0, 3.0],
+ [1.0, 2.0, 3.0]
+ ]
+ temp_dir = tempfile.mkdtemp()
+ load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
+ try:
+ self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
+ ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
+ ret = ret_rdd.collect()
+ self.assertEqual(len(ret), 2)
+ self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
+ self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
+ except:
+ self.fail()
+ finally:
+ shutil.rmtree(load_vectors_path)
+
+
+class LinearDataGeneratorTests(MLlibTestCase):
+ def test_dim(self):
+ linear_data = LinearDataGenerator.generateLinearInput(
+ intercept=0.0, weights=[0.0, 0.0, 0.0],
+ xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33],
+ nPoints=4, seed=0, eps=0.1)
+ self.assertEqual(len(linear_data), 4)
+ for point in linear_data:
+ self.assertEqual(len(point.features), 3)
+
+ linear_data = LinearDataGenerator.generateLinearRDD(
+ sc=self.sc, nexamples=6, nfeatures=2, eps=0.1,
+ nParts=2, intercept=0.0).collect()
+ self.assertEqual(len(linear_data), 6)
+ for point in linear_data:
+ self.assertEqual(len(point.features), 2)
+
+
+class SerDeTest(MLlibTestCase):
+ def test_to_java_object_rdd(self): # SPARK-6660
+ data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0)
+ self.assertEqual(_to_java_object_rdd(data).count(), 10)
+
+
+if __name__ == "__main__":
+ from pyspark.mllib.tests.test_util import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
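The new MLUtilsTests and LinearDataGeneratorTests above exercise public pyspark.mllib.util APIs. A minimal standalone sketch of the same calls, assuming a local SparkContext (both helpers need an active context because they delegate to the JVM):

from pyspark import SparkContext
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.util import LinearDataGenerator, MLUtils

sc = SparkContext("local[2]", "mlutils-sketch")

# appendBias adds a trailing 1.0 feature, as asserted in MLUtilsTests.
biased = MLUtils.appendBias(Vectors.dense([2.0, 2.0, 2.0]))
print(biased)  # DenseVector([2.0, 2.0, 2.0, 1.0])

# generateLinearInput returns LabeledPoints drawn around the given weights.
points = LinearDataGenerator.generateLinearInput(
    intercept=0.0, weights=[10.0], xMean=[0.0], xVariance=[1.0 / 3.0],
    nPoints=5, seed=42, eps=0.1)
print("%d points, %d feature(s) each" % (len(points), len(points[0].features)))

sc.stop()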
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8815ade20e92c..e69ae768efd6d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2354,7 +2354,7 @@ def countApproxDistinct(self, relativeSD=0.05):
The algorithm used is based on streamlib's implementation of
`"HyperLogLog in Practice: Algorithmic Engineering of a State
of The Art Cardinality Estimation Algorithm", available here
- <http://dx.doi.org/10.1145/2452376.2452456>`_.
+ <https://doi.org/10.1145/2452376.2452456>`_.
:param relativeSD: Relative accuracy. Smaller values create
counters that require more space.
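For context on the countApproxDistinct docstring touched above: relativeSD trades memory for accuracy, with smaller values giving tighter estimates. A minimal usage sketch, assuming a local SparkContext:

from pyspark import SparkContext

sc = SparkContext("local[2]", "hll-sketch")

# 10,000 elements but only 1,000 distinct values.
rdd = sc.parallelize([x % 1000 for x in range(10000)])

# Smaller relativeSD -> more counter space -> tighter estimate.
print(rdd.countApproxDistinct(relativeSD=0.05))  # roughly 1000, within ~5%

sc.stop()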
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 6202de36a479e..b8833a39078ba 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -257,7 +257,7 @@ def explain(self, extended=False):
>>> df.explain()
== Physical Plan ==
- Scan ExistingRDD[age#0,name#1]
+ *(1) Scan ExistingRDD[age#0,name#1]
>>> df.explain(True)
== Parsed Logical Plan ==
@@ -732,6 +732,11 @@ def repartitionByRange(self, numPartitions, *cols):
At least one partition-by expression must be specified.
When no explicit sort order is specified, "ascending nulls first" is assumed.
+ Note that due to performance reasons this method uses sampling to estimate the ranges.
+ Hence, the output may not be consistent, since sampling can return different values.
+ The sample size can be controlled by the config
+ `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
+
>>> df.repartitionByRange(2, "age").rdd.getNumPartitions()
2
>>> df.show()
@@ -1812,7 +1817,7 @@ def approxQuantile(self, col, probabilities, relativeError):
This method implements a variation of the Greenwald-Khanna
algorithm (with some speed optimizations). The algorithm was first
- present in [[http://dx.doi.org/10.1145/375663.375670
+ present in [[https://doi.org/10.1145/375663.375670
Space-efficient Online Computation of Quantile Summaries]]
by Greenwald and Khanna.
@@ -1934,7 +1939,7 @@ def freqItems(self, cols, support=None):
"""
Finding frequent items for columns, possibly with false positives. Using the
frequent element count algorithm described in
- "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou".
+ "https://doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou".
:func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases.
.. note:: This function is meant for exploratory data analysis, as we make no
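The repartitionByRange note added above names spark.sql.execution.rangeExchange.sampleSizePerPartition as the knob controlling how many rows are sampled per partition when estimating range boundaries. A minimal sketch of using it (the value is illustrative only), assuming a local SparkSession:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("range-repartition").getOrCreate()

df = spark.range(0, 1000).withColumnRenamed("id", "age")

# A larger sample per partition makes the estimated range boundaries more stable.
spark.conf.set("spark.sql.execution.rangeExchange.sampleSizePerPartition", "200")

print(df.repartitionByRange(4, "age").rdd.getNumPartitions())  # 4

spark.stop()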
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 519fb07352a73..9690c63f988fe 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2606,7 +2606,7 @@ def map_values(col):
return Column(sc._jvm.functions.map_values(_to_java_column(col)))
-@since(2.4)
+@since(3.0)
def map_entries(col):
"""
Collection function: Returns an unordered array of all entries in the given map.
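A quick usage sketch for map_entries, whose @since tag is bumped above; assumes a local SparkSession and a map-typed column:

from pyspark.sql import SparkSession
from pyspark.sql.functions import map_entries

spark = SparkSession.builder.master("local[2]").appName("map-entries").getOrCreate()

df = spark.createDataFrame([({"a": 1, "b": 2},)], ["data"])

# Each map entry becomes a (key, value) struct inside an array column.
df.select(map_entries("data").alias("entries")).show(truncate=False)

spark.stop()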
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 690b13072244b..1d2dd4d808930 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -177,7 +177,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
multiLine=None, allowUnquotedControlChars=None, lineSep=None, samplingRatio=None,
- dropFieldIfAllNull=None, encoding=None):
+ dropFieldIfAllNull=None, encoding=None, locale=None):
"""
Loads JSON files and returns the results as a :class:`DataFrame`.
@@ -249,6 +249,9 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
:param dropFieldIfAllNull: whether to ignore column of all null values or empty
array/struct during schema inference. If None is set, it
uses the default value, ``false``.
+ :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+ it uses the default value, ``en-US``. For instance, ``locale`` is used while
+ parsing dates and timestamps.
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -267,7 +270,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep,
- samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding)
+ samplingRatio=samplingRatio, dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding,
+ locale=locale)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -349,7 +353,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
- samplingRatio=None, enforceSchema=None, emptyValue=None):
+ samplingRatio=None, enforceSchema=None, emptyValue=None, locale=None, lineSep=None):
r"""Loads a CSV file and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -446,6 +450,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
If None is set, it uses the default value, ``1.0``.
:param emptyValue: sets the string representation of an empty value. If None is set, it uses
the default value, empty string.
+ :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+ it uses the default value, ``en-US``. For instance, ``locale`` is used while
+ parsing dates and timestamps.
+ :param lineSep: defines the line separator that should be used for parsing. If None is
+ set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+ Maximum length is 1 character.
>>> df = spark.read.csv('python/test_support/sql/ages.csv')
>>> df.dtypes
@@ -465,7 +475,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
- enforceSchema=enforceSchema, emptyValue=emptyValue)
+ enforceSchema=enforceSchema, emptyValue=emptyValue, locale=locale, lineSep=lineSep)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
@@ -861,7 +871,7 @@ def text(self, path, compression=None, lineSep=None):
def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None,
- charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None):
+ charToEscapeQuoteEscaping=None, encoding=None, emptyValue=None, lineSep=None):
r"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -915,6 +925,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
the default UTF-8 charset will be used.
:param emptyValue: sets the string representation of an empty value. If None is set, it uses
the default value, ``""``.
+ :param lineSep: defines the line separator that should be used for writing. If None is
+ set, it uses the default value, ``\\n``. Maximum length is 1 character.
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
"""
@@ -925,7 +937,7 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping,
- encoding=encoding, emptyValue=emptyValue)
+ encoding=encoding, emptyValue=emptyValue, lineSep=lineSep)
self._jwrite.csv(path)
@since(1.5)
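The readwriter.py hunks above thread locale and lineSep through DataFrameReader.csv and add lineSep to DataFrameWriter.csv. A minimal round-trip sketch, assuming this change is applied (these options are not available before it):

import os
import tempfile

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("csv-linesep").getOrCreate()

path = os.path.join(tempfile.mkdtemp(), "data")
df = spark.createDataFrame([("Alice", 25), ("Bob", 30)], ["name", "age"])

# Write with an explicit one-character line separator...
df.write.csv(path, header=True, lineSep="\n")

# ...and read it back, also pinning the locale used for date/timestamp parsing.
readback = spark.read.csv(path, header=True, lineSep="\n", locale="en-US")
print(readback.count())  # 2

spark.stop()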
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index b18453b2a4f96..d92b0d5677e25 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -404,7 +404,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
- multiLine=None, allowUnquotedControlChars=None, lineSep=None):
+ multiLine=None, allowUnquotedControlChars=None, lineSep=None, locale=None,
+ dropFieldIfAllNull=None, encoding=None):
"""
Loads a JSON file stream and returns the results as a :class:`DataFrame`.
@@ -469,6 +470,16 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
including tab and line feed characters) or not.
:param lineSep: defines the line separator that should be used for parsing. If None is
set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+ :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+ it uses the default value, ``en-US``. For instance, ``locale`` is used while
+ parsing dates and timestamps.
+ :param dropFieldIfAllNull: whether to ignore column of all null values or empty
+ array/struct during schema inference. If None is set, it
+ uses the default value, ``false``.
+ :param encoding: allows to forcibly set one of standard basic or extended encoding for
+ the JSON files. For example UTF-16BE, UTF-32LE. If None is set,
+ the encoding of input JSON will be detected automatically
+ when the multiLine option is set to ``true``.
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
@@ -483,7 +494,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, multiLine=multiLine,
- allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep)
+ allowUnquotedControlChars=allowUnquotedControlChars, lineSep=lineSep, locale=locale,
+ dropFieldIfAllNull=dropFieldIfAllNull, encoding=encoding)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
@@ -564,7 +576,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
- enforceSchema=None, emptyValue=None):
+ enforceSchema=None, emptyValue=None, locale=None, lineSep=None):
r"""Loads a CSV file stream and returns the result as a :class:`DataFrame`.
This function will go through the input once to determine the input schema if
@@ -660,6 +672,12 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
different, ``\0`` otherwise.
:param emptyValue: sets the string representation of an empty value. If None is set, it uses
the default value, empty string.
+ :param locale: sets a locale as language tag in IETF BCP 47 format. If None is set,
+ it uses the default value, ``en-US``. For instance, ``locale`` is used while
+ parsing dates and timestamps.
+ :param lineSep: defines the line separator that should be used for parsing. If None is
+ set, it covers all ``\\r``, ``\\r\\n`` and ``\\n``.
+ Maximum length is 1 character.
>>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
>>> csv_sdf.isStreaming
@@ -677,7 +695,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema,
- emptyValue=emptyValue)
+ emptyValue=emptyValue, locale=locale, lineSep=lineSep)
if isinstance(path, basestring):
return self._df(self._jreader.csv(path))
else:
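The streaming.py hunks above bring the streaming JSON and CSV sources in line with the batch readers (locale, dropFieldIfAllNull, encoding, lineSep). A minimal sketch mirroring the doctest style used above, assuming this change is applied:

import tempfile

from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, StringType, StructField, StructType

spark = SparkSession.builder.master("local[2]").appName("stream-json").getOrCreate()

schema = StructType([
    StructField("name", StringType()),
    StructField("age", LongType()),
])

# The streaming JSON source now accepts the same parsing options as spark.read.json.
json_sdf = spark.readStream.json(
    tempfile.mkdtemp(), schema=schema, locale="en-US", dropFieldIfAllNull=True)
print(json_sdf.isStreaming)  # True

spark.stop()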
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
deleted file mode 100644
index 206d529454f3d..0000000000000
--- a/python/pyspark/sql/tests.py
+++ /dev/null
@@ -1,7109 +0,0 @@
-# -*- encoding: utf-8 -*-
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Unit tests for pyspark.sql; additional tests are implemented as doctests in
-individual modules.
-"""
-import os
-import sys
-import subprocess
-import pydoc
-import shutil
-import tempfile
-import threading
-import pickle
-import functools
-import time
-import datetime
-import array
-import ctypes
-import warnings
-import py4j
-from contextlib import contextmanager
-import unishark
-
-if sys.version_info[:2] <= (2, 6):
- try:
- import unittest2 as unittest
- except ImportError:
- sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
- sys.exit(1)
-else:
- import unittest
-
-from pyspark.util import _exception_message
-
-_pandas_requirement_message = None
-try:
- from pyspark.sql.utils import require_minimum_pandas_version
- require_minimum_pandas_version()
-except ImportError as e:
- # If Pandas version requirement is not satisfied, skip related tests.
- _pandas_requirement_message = _exception_message(e)
-
-_pyarrow_requirement_message = None
-try:
- from pyspark.sql.utils import require_minimum_pyarrow_version
- require_minimum_pyarrow_version()
-except ImportError as e:
- # If Arrow version requirement is not satisfied, skip related tests.
- _pyarrow_requirement_message = _exception_message(e)
-
-_test_not_compiled_message = None
-try:
- from pyspark.sql.utils import require_test_compiled
- require_test_compiled()
-except Exception as e:
- _test_not_compiled_message = _exception_message(e)
-
-_have_pandas = _pandas_requirement_message is None
-_have_pyarrow = _pyarrow_requirement_message is None
-_test_compiled = _test_not_compiled_message is None
-
-from pyspark import SparkConf, SparkContext
-from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
-from pyspark.sql.types import *
-from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
-from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
-from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
-from pyspark.sql.types import _merge_type
-from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests
-from pyspark.sql.functions import UserDefinedFunction, sha2, lit
-from pyspark.sql.window import Window
-from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
-
-
-class UTCOffsetTimezone(datetime.tzinfo):
- """
- Specifies timezone in UTC offset
- """
-
- def __init__(self, offset=0):
- self.ZERO = datetime.timedelta(hours=offset)
-
- def utcoffset(self, dt):
- return self.ZERO
-
- def dst(self, dt):
- return self.ZERO
-
-
-class ExamplePointUDT(UserDefinedType):
- """
- User-defined type (UDT) for ExamplePoint.
- """
-
- @classmethod
- def sqlType(self):
- return ArrayType(DoubleType(), False)
-
- @classmethod
- def module(cls):
- return 'pyspark.sql.tests'
-
- @classmethod
- def scalaUDT(cls):
- return 'org.apache.spark.sql.test.ExamplePointUDT'
-
- def serialize(self, obj):
- return [obj.x, obj.y]
-
- def deserialize(self, datum):
- return ExamplePoint(datum[0], datum[1])
-
-
-class ExamplePoint:
- """
- An example class to demonstrate UDT in Scala, Java, and Python.
- """
-
- __UDT__ = ExamplePointUDT()
-
- def __init__(self, x, y):
- self.x = x
- self.y = y
-
- def __repr__(self):
- return "ExamplePoint(%s,%s)" % (self.x, self.y)
-
- def __str__(self):
- return "(%s,%s)" % (self.x, self.y)
-
- def __eq__(self, other):
- return isinstance(other, self.__class__) and \
- other.x == self.x and other.y == self.y
-
-
-class PythonOnlyUDT(UserDefinedType):
- """
- User-defined type (UDT) for ExamplePoint.
- """
-
- @classmethod
- def sqlType(self):
- return ArrayType(DoubleType(), False)
-
- @classmethod
- def module(cls):
- return '__main__'
-
- def serialize(self, obj):
- return [obj.x, obj.y]
-
- def deserialize(self, datum):
- return PythonOnlyPoint(datum[0], datum[1])
-
- @staticmethod
- def foo():
- pass
-
- @property
- def props(self):
- return {}
-
-
-class PythonOnlyPoint(ExamplePoint):
- """
- An example class to demonstrate UDT in only Python
- """
- __UDT__ = PythonOnlyUDT()
-
-
-class MyObject(object):
- def __init__(self, key, value):
- self.key = key
- self.value = value
-
-
-class SQLTestUtils(object):
- """
- This util assumes the instance of this to have 'spark' attribute, having a spark session.
- It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the
- the implementation of this class has 'spark' attribute.
- """
-
- @contextmanager
- def sql_conf(self, pairs):
- """
- A convenient context manager to test some configuration specific logic. This sets
- `value` to the configuration `key` and then restores it back when it exits.
- """
- assert isinstance(pairs, dict), "pairs should be a dictionary."
- assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
-
- keys = pairs.keys()
- new_values = pairs.values()
- old_values = [self.spark.conf.get(key, None) for key in keys]
- for key, new_value in zip(keys, new_values):
- self.spark.conf.set(key, new_value)
- try:
- yield
- finally:
- for key, old_value in zip(keys, old_values):
- if old_value is None:
- self.spark.conf.unset(key)
- else:
- self.spark.conf.set(key, old_value)
-
- @contextmanager
- def database(self, *databases):
- """
- A convenient context manager to test with some specific databases. This drops the given
- databases if exist and sets current database to "default" when it exits.
- """
- assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
-
- try:
- yield
- finally:
- for db in databases:
- self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
- self.spark.catalog.setCurrentDatabase("default")
-
- @contextmanager
- def table(self, *tables):
- """
- A convenient context manager to test with some specific tables. This drops the given tables
- if exist when it exits.
- """
- assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
-
- try:
- yield
- finally:
- for t in tables:
- self.spark.sql("DROP TABLE IF EXISTS %s" % t)
-
- @contextmanager
- def tempView(self, *views):
- """
- A convenient context manager to test with some specific views. This drops the given views
- if exist when it exits.
- """
- assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
-
- try:
- yield
- finally:
- for v in views:
- self.spark.catalog.dropTempView(v)
-
- @contextmanager
- def function(self, *functions):
- """
- A convenient context manager to test with some specific functions. This drops the given
- functions if exist when it exits.
- """
- assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
-
- try:
- yield
- finally:
- for f in functions:
- self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
-
-
-class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
- @classmethod
- def setUpClass(cls):
- super(ReusedSQLTestCase, cls).setUpClass()
- cls.spark = SparkSession(cls.sc)
-
- @classmethod
- def tearDownClass(cls):
- super(ReusedSQLTestCase, cls).tearDownClass()
- cls.spark.stop()
-
- def assertPandasEqual(self, expected, result):
- msg = ("DataFrames are not equal: " +
- "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
- "\n\nResult:\n%s\n%s" % (result, result.dtypes))
- self.assertTrue(expected.equals(result), msg=msg)
-
-
-class DataTypeTests(unittest.TestCase):
- # regression test for SPARK-6055
- def test_data_type_eq(self):
- lt = LongType()
- lt2 = pickle.loads(pickle.dumps(LongType()))
- self.assertEqual(lt, lt2)
-
- # regression test for SPARK-7978
- def test_decimal_type(self):
- t1 = DecimalType()
- t2 = DecimalType(10, 2)
- self.assertTrue(t2 is not t1)
- self.assertNotEqual(t1, t2)
- t3 = DecimalType(8)
- self.assertNotEqual(t2, t3)
-
- # regression test for SPARK-10392
- def test_datetype_equal_zero(self):
- dt = DateType()
- self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))
-
- # regression test for SPARK-17035
- def test_timestamp_microsecond(self):
- tst = TimestampType()
- self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999)
-
- def test_empty_row(self):
- row = Row()
- self.assertEqual(len(row), 0)
-
- def test_struct_field_type_name(self):
- struct_field = StructField("a", IntegerType())
- self.assertRaises(TypeError, struct_field.typeName)
-
- def test_invalid_create_row(self):
- row_class = Row("c1", "c2")
- self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
-
-
-class SparkSessionBuilderTests(unittest.TestCase):
-
- def test_create_spark_context_first_then_spark_session(self):
- sc = None
- session = None
- try:
- conf = SparkConf().set("key1", "value1")
- sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf)
- session = SparkSession.builder.config("key2", "value2").getOrCreate()
-
- self.assertEqual(session.conf.get("key1"), "value1")
- self.assertEqual(session.conf.get("key2"), "value2")
- self.assertEqual(session.sparkContext, sc)
-
- self.assertFalse(sc.getConf().contains("key2"))
- self.assertEqual(sc.getConf().get("key1"), "value1")
- finally:
- if session is not None:
- session.stop()
- if sc is not None:
- sc.stop()
-
- def test_another_spark_session(self):
- session1 = None
- session2 = None
- try:
- session1 = SparkSession.builder.config("key1", "value1").getOrCreate()
- session2 = SparkSession.builder.config("key2", "value2").getOrCreate()
-
- self.assertEqual(session1.conf.get("key1"), "value1")
- self.assertEqual(session2.conf.get("key1"), "value1")
- self.assertEqual(session1.conf.get("key2"), "value2")
- self.assertEqual(session2.conf.get("key2"), "value2")
- self.assertEqual(session1.sparkContext, session2.sparkContext)
-
- self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1")
- self.assertFalse(session1.sparkContext.getConf().contains("key2"))
- finally:
- if session1 is not None:
- session1.stop()
- if session2 is not None:
- session2.stop()
-
-
-class SQLTests(ReusedSQLTestCase):
-
- @classmethod
- def setUpClass(cls):
- ReusedSQLTestCase.setUpClass()
- cls.spark.catalog._reset()
- cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
- os.unlink(cls.tempdir.name)
- cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
- cls.df = cls.spark.createDataFrame(cls.testData)
-
- @classmethod
- def tearDownClass(cls):
- ReusedSQLTestCase.tearDownClass()
- shutil.rmtree(cls.tempdir.name, ignore_errors=True)
-
- def test_sqlcontext_reuses_sparksession(self):
- sqlContext1 = SQLContext(self.sc)
- sqlContext2 = SQLContext(self.sc)
- self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)
-
- def test_row_should_be_read_only(self):
- row = Row(a=1, b=2)
- self.assertEqual(1, row.a)
-
- def foo():
- row.a = 3
- self.assertRaises(Exception, foo)
-
- row2 = self.spark.range(10).first()
- self.assertEqual(0, row2.id)
-
- def foo2():
- row2.id = 2
- self.assertRaises(Exception, foo2)
-
- def test_range(self):
- self.assertEqual(self.spark.range(1, 1).count(), 0)
- self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
- self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
- self.assertEqual(self.spark.range(-2).count(), 0)
- self.assertEqual(self.spark.range(3).count(), 3)
-
- def test_duplicated_column_names(self):
- df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
- row = df.select('*').first()
- self.assertEqual(1, row[0])
- self.assertEqual(2, row[1])
- self.assertEqual("Row(c=1, c=2)", str(row))
- # Cannot access columns
- self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
- self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
- self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
-
- def test_column_name_encoding(self):
- """Ensure that created columns has `str` type consistently."""
- columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns
- self.assertEqual(columns, ['name', 'age'])
- self.assertTrue(isinstance(columns[0], str))
- self.assertTrue(isinstance(columns[1], str))
-
- def test_explode(self):
- from pyspark.sql.functions import explode, explode_outer, posexplode_outer
- d = [
- Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
- Row(a=1, intlist=[], mapfield={}),
- Row(a=1, intlist=None, mapfield=None),
- ]
- rdd = self.sc.parallelize(d)
- data = self.spark.createDataFrame(rdd)
-
- result = data.select(explode(data.intlist).alias("a")).select("a").collect()
- self.assertEqual(result[0][0], 1)
- self.assertEqual(result[1][0], 2)
- self.assertEqual(result[2][0], 3)
-
- result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
- self.assertEqual(result[0][0], "a")
- self.assertEqual(result[0][1], "b")
-
- result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
- self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])
-
- result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
- self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])
-
- result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
- self.assertEqual(result, [1, 2, 3, None, None])
-
- result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
- self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])
-
- def test_and_in_expression(self):
- self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
- self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
- self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count())
- self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2")
- self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count())
- self.assertRaises(ValueError, lambda: not self.df.key == 1)
-
- def test_udf_with_callable(self):
- d = [Row(number=i, squared=i**2) for i in range(10)]
- rdd = self.sc.parallelize(d)
- data = self.spark.createDataFrame(rdd)
-
- class PlusFour:
- def __call__(self, col):
- if col is not None:
- return col + 4
-
- call = PlusFour()
- pudf = UserDefinedFunction(call, LongType())
- res = data.select(pudf(data['number']).alias('plus_four'))
- self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
-
- def test_udf_with_partial_function(self):
- d = [Row(number=i, squared=i**2) for i in range(10)]
- rdd = self.sc.parallelize(d)
- data = self.spark.createDataFrame(rdd)
-
- def some_func(col, param):
- if col is not None:
- return col + param
-
- pfunc = functools.partial(some_func, param=4)
- pudf = UserDefinedFunction(pfunc, LongType())
- res = data.select(pudf(data['number']).alias('plus_four'))
- self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
-
- def test_udf(self):
- self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
- [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
- self.assertEqual(row[0], 5)
-
- # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
- sqlContext = self.spark._wrapped
- sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
- [row] = sqlContext.sql("SELECT oneArg('test')").collect()
- self.assertEqual(row[0], 4)
-
- def test_udf2(self):
- with self.tempView("test"):
- self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
- self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
- .createOrReplaceTempView("test")
- [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
- self.assertEqual(4, res[0])
-
- def test_udf3(self):
- two_args = self.spark.catalog.registerFunction(
- "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y))
- self.assertEqual(two_args.deterministic, True)
- [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
- self.assertEqual(row[0], u'5')
-
- def test_udf_registration_return_type_none(self):
- two_args = self.spark.catalog.registerFunction(
- "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None)
- self.assertEqual(two_args.deterministic, True)
- [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
- self.assertEqual(row[0], 5)
-
- def test_udf_registration_return_type_not_none(self):
- with QuietTest(self.sc):
- with self.assertRaisesRegexp(TypeError, "Invalid returnType"):
- self.spark.catalog.registerFunction(
- "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType())
-
- def test_nondeterministic_udf(self):
- # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
- from pyspark.sql.functions import udf
- import random
- udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
- self.assertEqual(udf_random_col.deterministic, False)
- df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
- udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
- [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
- self.assertEqual(row[0] + 10, row[1])
-
- def test_nondeterministic_udf2(self):
- import random
- from pyspark.sql.functions import udf
- random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
- self.assertEqual(random_udf.deterministic, False)
- random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
- self.assertEqual(random_udf1.deterministic, False)
- [row] = self.spark.sql("SELECT randInt()").collect()
- self.assertEqual(row[0], 6)
- [row] = self.spark.range(1).select(random_udf1()).collect()
- self.assertEqual(row[0], 6)
- [row] = self.spark.range(1).select(random_udf()).collect()
- self.assertEqual(row[0], 6)
- # render_doc() reproduces the help() exception without printing output
- pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
- pydoc.render_doc(random_udf)
- pydoc.render_doc(random_udf1)
- pydoc.render_doc(udf(lambda x: x).asNondeterministic)
-
- def test_nondeterministic_udf3(self):
- # regression test for SPARK-23233
- from pyspark.sql.functions import udf
- f = udf(lambda x: x)
- # Here we cache the JVM UDF instance.
- self.spark.range(1).select(f("id"))
- # This should reset the cache to set the deterministic status correctly.
- f = f.asNondeterministic()
- # Check the deterministic status of udf.
- df = self.spark.range(1).select(f("id"))
- deterministic = df._jdf.logicalPlan().projectList().head().deterministic()
- self.assertFalse(deterministic)
-
- def test_nondeterministic_udf_in_aggregate(self):
- from pyspark.sql.functions import udf, sum
- import random
- udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
- df = self.spark.range(10)
-
- with QuietTest(self.sc):
- with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
- df.groupby('id').agg(sum(udf_random_col())).collect()
- with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
- df.agg(sum(udf_random_col())).collect()
-
- def test_chained_udf(self):
- self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
- [row] = self.spark.sql("SELECT double(1)").collect()
- self.assertEqual(row[0], 2)
- [row] = self.spark.sql("SELECT double(double(1))").collect()
- self.assertEqual(row[0], 4)
- [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
- self.assertEqual(row[0], 6)
-
- def test_single_udf_with_repeated_argument(self):
- # regression test for SPARK-20685
- self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
- row = self.spark.sql("SELECT add(1, 1)").first()
- self.assertEqual(tuple(row), (2, ))
-
- def test_multiple_udfs(self):
- self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
- [row] = self.spark.sql("SELECT double(1), double(2)").collect()
- self.assertEqual(tuple(row), (2, 4))
- [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
- self.assertEqual(tuple(row), (4, 12))
- self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
- [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
- self.assertEqual(tuple(row), (6, 5))
-
- def test_udf_in_filter_on_top_of_outer_join(self):
- from pyspark.sql.functions import udf
- left = self.spark.createDataFrame([Row(a=1)])
- right = self.spark.createDataFrame([Row(a=1)])
- df = left.join(right, on='a', how='left_outer')
- df = df.withColumn('b', udf(lambda x: 'x')(df.a))
- self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
-
- def test_udf_in_filter_on_top_of_join(self):
- # regression test for SPARK-18589
- from pyspark.sql.functions import udf
- left = self.spark.createDataFrame([Row(a=1)])
- right = self.spark.createDataFrame([Row(b=1)])
- f = udf(lambda a, b: a == b, BooleanType())
- df = left.crossJoin(right).filter(f("a", "b"))
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
-
- def test_udf_in_join_condition(self):
- # regression test for SPARK-25314
- from pyspark.sql.functions import udf
- left = self.spark.createDataFrame([Row(a=1)])
- right = self.spark.createDataFrame([Row(b=1)])
- f = udf(lambda a, b: a == b, BooleanType())
- df = left.join(right, f("a", "b"))
- with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
- df.collect()
- with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
-
- def test_udf_in_left_semi_join_condition(self):
- # regression test for SPARK-25314
- from pyspark.sql.functions import udf
- left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
- right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
- f = udf(lambda a, b: a == b, BooleanType())
- df = left.join(right, f("a", "b"), "leftsemi")
- with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
- df.collect()
- with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
- self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
-
- def test_udf_and_common_filter_in_join_condition(self):
- # regression test for SPARK-25314
- # test the complex scenario with both udf and common filter
- from pyspark.sql.functions import udf
- left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
- right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
- f = udf(lambda a, b: a == b, BooleanType())
- df = left.join(right, [f("a", "b"), left.a1 == right.b1])
- # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition.
- self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
-
- def test_udf_and_common_filter_in_left_semi_join_condition(self):
- # regression test for SPARK-25314
- # test the complex scenario with both udf and common filter
- from pyspark.sql.functions import udf
- left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
- right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
- f = udf(lambda a, b: a == b, BooleanType())
- df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi")
- # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition.
- self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
-
- def test_udf_not_supported_in_join_condition(self):
- # regression test for SPARK-25314
- # test python udf is not supported in join type besides left_semi and inner join.
- from pyspark.sql.functions import udf
- left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
- right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
- f = udf(lambda a, b: a == b, BooleanType())
-
- def runWithJoinType(join_type, type_string):
- with self.assertRaisesRegexp(
- AnalysisException,
- 'Using PythonUDF.*%s is not supported.' % type_string):
- left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
- runWithJoinType("full", "FullOuter")
- runWithJoinType("left", "LeftOuter")
- runWithJoinType("right", "RightOuter")
- runWithJoinType("leftanti", "LeftAnti")
-
- def test_udf_without_arguments(self):
- self.spark.catalog.registerFunction("foo", lambda: "bar")
- [row] = self.spark.sql("SELECT foo()").collect()
- self.assertEqual(row[0], "bar")
-
- def test_udf_with_array_type(self):
- with self.tempView("test"):
- d = [Row(l=list(range(3)), d={"key": list(range(5))})]
- rdd = self.sc.parallelize(d)
- self.spark.createDataFrame(rdd).createOrReplaceTempView("test")
- self.spark.catalog.registerFunction(
- "copylist", lambda l: list(l), ArrayType(IntegerType()))
- self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType())
- [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect()
- self.assertEqual(list(range(3)), l1)
- self.assertEqual(1, l2)
-
- def test_broadcast_in_udf(self):
- bar = {"a": "aa", "b": "bb", "c": "abc"}
- foo = self.sc.broadcast(bar)
- self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
- [res] = self.spark.sql("SELECT MYUDF('c')").collect()
- self.assertEqual("abc", res[0])
- [res] = self.spark.sql("SELECT MYUDF('')").collect()
- self.assertEqual("", res[0])
-
- def test_udf_with_filter_function(self):
- df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
- from pyspark.sql.functions import udf, col
- from pyspark.sql.types import BooleanType
-
- my_filter = udf(lambda a: a < 2, BooleanType())
- sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
- self.assertEqual(sel.collect(), [Row(key=1, value='1')])
-
- def test_udf_with_aggregate_function(self):
- df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
- from pyspark.sql.functions import udf, col, sum
- from pyspark.sql.types import BooleanType
-
- my_filter = udf(lambda a: a == 1, BooleanType())
- sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
- self.assertEqual(sel.collect(), [Row(key=1)])
-
- my_copy = udf(lambda x: x, IntegerType())
- my_add = udf(lambda a, b: int(a + b), IntegerType())
- my_strlen = udf(lambda x: len(x), IntegerType())
- sel = df.groupBy(my_copy(col("key")).alias("k"))\
- .agg(sum(my_strlen(col("value"))).alias("s"))\
- .select(my_add(col("k"), col("s")).alias("t"))
- self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
-
- def test_udf_in_generate(self):
- from pyspark.sql.functions import udf, explode
- df = self.spark.range(5)
- f = udf(lambda x: list(range(x)), ArrayType(LongType()))
- row = df.select(explode(f(*df))).groupBy().sum().first()
- self.assertEqual(row[0], 10)
-
- df = self.spark.range(3)
- res = df.select("id", explode(f(df.id))).collect()
- self.assertEqual(res[0][0], 1)
- self.assertEqual(res[0][1], 0)
- self.assertEqual(res[1][0], 2)
- self.assertEqual(res[1][1], 0)
- self.assertEqual(res[2][0], 2)
- self.assertEqual(res[2][1], 1)
-
- range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
- res = df.select("id", explode(range_udf(df.id))).collect()
- self.assertEqual(res[0][0], 0)
- self.assertEqual(res[0][1], -1)
- self.assertEqual(res[1][0], 0)
- self.assertEqual(res[1][1], 0)
- self.assertEqual(res[2][0], 1)
- self.assertEqual(res[2][1], 0)
- self.assertEqual(res[3][0], 1)
- self.assertEqual(res[3][1], 1)
-
- def test_udf_with_order_by_and_limit(self):
- from pyspark.sql.functions import udf
- my_copy = udf(lambda x: x, IntegerType())
- df = self.spark.range(10).orderBy("id")
- res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
- res.explain(True)
- self.assertEqual(res.collect(), [Row(id=0, copy=0)])
-
- def test_udf_registration_returns_udf(self):
- df = self.spark.range(10)
- add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
-
- self.assertListEqual(
- df.selectExpr("add_three(id) AS plus_three").collect(),
- df.select(add_three("id").alias("plus_three")).collect()
- )
-
- # This is to check if a 'SQLContext.udf' can call its alias.
- sqlContext = self.spark._wrapped
- add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
-
- self.assertListEqual(
- df.selectExpr("add_four(id) AS plus_four").collect(),
- df.select(add_four("id").alias("plus_four")).collect()
- )
-
- def test_non_existed_udf(self):
- spark = self.spark
- self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
- lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
-
- # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
- sqlContext = spark._wrapped
- self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
- lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
-
- def test_non_existed_udaf(self):
- spark = self.spark
- self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
- lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
-
- def test_linesep_text(self):
- df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",")
- expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'),
- Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'),
- Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'),
- Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')]
- self.assertEqual(df.collect(), expected)
-
- tpath = tempfile.mkdtemp()
- shutil.rmtree(tpath)
- try:
- df.write.text(tpath, lineSep="!")
- expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'),
- Row(value=u'Tom!30!"My name is Tom"'),
- Row(value=u'Hyukjin!25!"I am Hyukjin'),
- Row(value=u''), Row(value=u'I love Spark!"'),
- Row(value=u'!')]
- readback = self.spark.read.text(tpath)
- self.assertEqual(readback.collect(), expected)
- finally:
- shutil.rmtree(tpath)
-
- def test_multiline_json(self):
- people1 = self.spark.read.json("python/test_support/sql/people.json")
- people_array = self.spark.read.json("python/test_support/sql/people_array.json",
- multiLine=True)
- self.assertEqual(people1.collect(), people_array.collect())
-
- def test_encoding_json(self):
- people_array = self.spark.read\
- .json("python/test_support/sql/people_array_utf16le.json",
- multiLine=True, encoding="UTF-16LE")
- expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')]
- self.assertEqual(people_array.collect(), expected)
-
- def test_linesep_json(self):
- df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",")
- expected = [Row(_corrupt_record=None, name=u'Michael'),
- Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None),
- Row(_corrupt_record=u' "age":19}\n', name=None)]
- self.assertEqual(df.collect(), expected)
-
- tpath = tempfile.mkdtemp()
- shutil.rmtree(tpath)
- try:
- df = self.spark.read.json("python/test_support/sql/people.json")
- df.write.json(tpath, lineSep="!!")
- readback = self.spark.read.json(tpath, lineSep="!!")
- self.assertEqual(readback.collect(), df.collect())
- finally:
- shutil.rmtree(tpath)
-
- def test_multiline_csv(self):
- ages_newlines = self.spark.read.csv(
- "python/test_support/sql/ages_newlines.csv", multiLine=True)
- expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'),
- Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'),
- Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
- self.assertEqual(ages_newlines.collect(), expected)
-
- def test_ignorewhitespace_csv(self):
- tmpPath = tempfile.mkdtemp()
- shutil.rmtree(tmpPath)
- self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv(
- tmpPath,
- ignoreLeadingWhiteSpace=False,
- ignoreTrailingWhiteSpace=False)
-
- expected = [Row(value=u' a,b , c ')]
- readback = self.spark.read.text(tmpPath)
- self.assertEqual(readback.collect(), expected)
- shutil.rmtree(tmpPath)
-
- def test_read_multiple_orc_file(self):
- df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0",
- "python/test_support/sql/orc_partitioned/b=1/c=1"])
- self.assertEqual(2, df.count())
-
- def test_udf_with_input_file_name(self):
- from pyspark.sql.functions import udf, input_file_name
- sourceFile = udf(lambda path: path, StringType())
- filePath = "python/test_support/sql/people1.json"
- row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
- self.assertTrue(row[0].find("people1.json") != -1)
-
- def test_udf_with_input_file_name_for_hadooprdd(self):
- from pyspark.sql.functions import udf, input_file_name
-
- def filename(path):
- return path
-
- sameText = udf(filename, StringType())
-
- rdd = self.sc.textFile('python/test_support/sql/people.json')
- df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
- row = df.select(sameText(df['file'])).first()
- self.assertTrue(row[0].find("people.json") != -1)
-
- rdd2 = self.sc.newAPIHadoopFile(
- 'python/test_support/sql/people.json',
- 'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
- 'org.apache.hadoop.io.LongWritable',
- 'org.apache.hadoop.io.Text')
-
- df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
- row2 = df2.select(sameText(df2['file'])).first()
- self.assertTrue(row2[0].find("people.json") != -1)
-
- def test_udf_defers_judf_initialization(self):
- # This is separate of UDFInitializationTests
- # to avoid context initialization
- # when udf is called
-
- from pyspark.sql.functions import UserDefinedFunction
-
- f = UserDefinedFunction(lambda x: x, StringType())
-
- self.assertIsNone(
- f._judf_placeholder,
- "judf should not be initialized before the first call."
- )
-
- self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")
-
- self.assertIsNotNone(
- f._judf_placeholder,
- "judf should be initialized after UDF has been called."
- )
-
- def test_udf_with_string_return_type(self):
- from pyspark.sql.functions import UserDefinedFunction
-
- add_one = UserDefinedFunction(lambda x: x + 1, "integer")
- make_pair = UserDefinedFunction(lambda x: (-x, x), "struct")
- make_array = UserDefinedFunction(
- lambda x: [float(x) for x in range(x, x + 3)], "array")
-
- expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
- actual = (self.spark.range(1, 2).toDF("x")
- .select(add_one("x"), make_pair("x"), make_array("x"))
- .first())
-
- self.assertTupleEqual(expected, actual)
-
- def test_udf_shouldnt_accept_noncallable_object(self):
- from pyspark.sql.functions import UserDefinedFunction
-
- non_callable = None
- self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
-
- def test_udf_with_decorator(self):
- from pyspark.sql.functions import lit, udf
- from pyspark.sql.types import IntegerType, DoubleType
-
- @udf(IntegerType())
- def add_one(x):
- if x is not None:
- return x + 1
-
- @udf(returnType=DoubleType())
- def add_two(x):
- if x is not None:
- return float(x + 2)
-
- @udf
- def to_upper(x):
- if x is not None:
- return x.upper()
-
- @udf()
- def to_lower(x):
- if x is not None:
- return x.lower()
-
- @udf
- def substr(x, start, end):
- if x is not None:
- return x[start:end]
-
- @udf("long")
- def trunc(x):
- return int(x)
-
- @udf(returnType="double")
- def as_double(x):
- return float(x)
-
- df = (
- self.spark
- .createDataFrame(
- [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
- .select(
- add_one("one"), add_two("one"),
- to_upper("Foo"), to_lower("Foo"),
- substr("foobar", lit(0), lit(3)),
- trunc("float"), as_double("one")))
-
- self.assertListEqual(
- [tpe for _, tpe in df.dtypes],
- ["int", "double", "string", "string", "string", "bigint", "double"]
- )
-
- self.assertListEqual(
- list(df.first()),
- [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
- )
-
- def test_udf_wrapper(self):
- from pyspark.sql.functions import udf
- from pyspark.sql.types import IntegerType
-
- def f(x):
- """Identity"""
- return x
-
- return_type = IntegerType()
- f_ = udf(f, return_type)
-
- self.assertTrue(f.__doc__ in f_.__doc__)
- self.assertEqual(f, f_.func)
- self.assertEqual(return_type, f_.returnType)
-
- class F(object):
- """Identity"""
- def __call__(self, x):
- return x
-
- f = F()
- return_type = IntegerType()
- f_ = udf(f, return_type)
-
- self.assertTrue(f.__doc__ in f_.__doc__)
- self.assertEqual(f, f_.func)
- self.assertEqual(return_type, f_.returnType)
-
- f = functools.partial(f, x=1)
- return_type = IntegerType()
- f_ = udf(f, return_type)
-
- self.assertTrue(f.__doc__ in f_.__doc__)
- self.assertEqual(f, f_.func)
- self.assertEqual(return_type, f_.returnType)
-
- def test_validate_column_types(self):
- from pyspark.sql.functions import udf, to_json
- from pyspark.sql.column import _to_java_column
-
- self.assertTrue("Column" in _to_java_column("a").getClass().toString())
- self.assertTrue("Column" in _to_java_column(u"a").getClass().toString())
- self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString())
-
- self.assertRaisesRegexp(
- TypeError,
- "Invalid argument, not a string or column",
- lambda: _to_java_column(1))
-
- class A():
- pass
-
- self.assertRaises(TypeError, lambda: _to_java_column(A()))
- self.assertRaises(TypeError, lambda: _to_java_column([]))
-
- self.assertRaisesRegexp(
- TypeError,
- "Invalid argument, not a string or column",
- lambda: udf(lambda x: x)(None))
- self.assertRaises(TypeError, lambda: to_json(1))
-
- def test_basic_functions(self):
- rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
- df = self.spark.read.json(rdd)
- df.count()
- df.collect()
- df.schema
-
- # cache and checkpoint
- self.assertFalse(df.is_cached)
- df.persist()
- df.unpersist(True)
- df.cache()
- self.assertTrue(df.is_cached)
- self.assertEqual(2, df.count())
-
- with self.tempView("temp"):
- df.createOrReplaceTempView("temp")
- df = self.spark.sql("select foo from temp")
- df.count()
- df.collect()
-
- def test_apply_schema_to_row(self):
- df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
- df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
- self.assertEqual(df.collect(), df2.collect())
-
- rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- df3 = self.spark.createDataFrame(rdd, df.schema)
- self.assertEqual(10, df3.count())
-
- def test_infer_schema_to_local(self):
- input = [{"a": 1}, {"b": "coffee"}]
- rdd = self.sc.parallelize(input)
- df = self.spark.createDataFrame(input)
- df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
- self.assertEqual(df.schema, df2.schema)
-
- rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
- df3 = self.spark.createDataFrame(rdd, df.schema)
- self.assertEqual(10, df3.count())
-
- def test_apply_schema_to_dict_and_rows(self):
- schema = StructType().add("b", StringType()).add("a", IntegerType())
- input = [{"a": 1}, {"b": "coffee"}]
- rdd = self.sc.parallelize(input)
- for verify in [False, True]:
- df = self.spark.createDataFrame(input, schema, verifySchema=verify)
- df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
- self.assertEqual(df.schema, df2.schema)
-
- rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
- df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
- self.assertEqual(10, df3.count())
- input = [Row(a=x, b=str(x)) for x in range(10)]
- df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
- self.assertEqual(10, df4.count())
-
- def test_create_dataframe_schema_mismatch(self):
- input = [Row(a=1)]
- rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
- schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
- df = self.spark.createDataFrame(rdd, schema)
- self.assertRaises(Exception, lambda: df.show())
-
- def test_serialize_nested_array_and_map(self):
- d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
- rdd = self.sc.parallelize(d)
- df = self.spark.createDataFrame(rdd)
- row = df.head()
- self.assertEqual(1, len(row.l))
- self.assertEqual(1, row.l[0].a)
- self.assertEqual("2", row.d["key"].d)
-
- l = df.rdd.map(lambda x: x.l).first()
- self.assertEqual(1, len(l))
- self.assertEqual('s', l[0].b)
-
- d = df.rdd.map(lambda x: x.d).first()
- self.assertEqual(1, len(d))
- self.assertEqual(1.0, d["key"].c)
-
- row = df.rdd.map(lambda x: x.d["key"]).first()
- self.assertEqual(1.0, row.c)
- self.assertEqual("2", row.d)
-
- def test_infer_schema(self):
- d = [Row(l=[], d={}, s=None),
- Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
- rdd = self.sc.parallelize(d)
- df = self.spark.createDataFrame(rdd)
- self.assertEqual([], df.rdd.map(lambda r: r.l).first())
- self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
-
- with self.tempView("test"):
- df.createOrReplaceTempView("test")
- result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'")
- self.assertEqual(1, result.head()[0])
-
- df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
- self.assertEqual(df.schema, df2.schema)
- self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
- self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
-
- with self.tempView("test2"):
- df2.createOrReplaceTempView("test2")
- result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
- self.assertEqual(1, result.head()[0])
-
- def test_infer_schema_specification(self):
- from decimal import Decimal
-
- class A(object):
- def __init__(self):
- self.a = 1
-
- data = [
- True,
- 1,
- "a",
- u"a",
- datetime.date(1970, 1, 1),
- datetime.datetime(1970, 1, 1, 0, 0),
- 1.0,
- array.array("d", [1]),
- [1],
- (1, ),
- {"a": 1},
- bytearray(1),
- Decimal(1),
- Row(a=1),
- Row("a")(1),
- A(),
- ]
-
- df = self.spark.createDataFrame([data])
- actual = list(map(lambda x: x.dataType.simpleString(), df.schema))
- expected = [
- 'boolean',
- 'bigint',
- 'string',
- 'string',
- 'date',
- 'timestamp',
- 'double',
- 'array<double>',
- 'array<bigint>',
- 'struct<_1:bigint>',
- 'map<string,bigint>',
- 'binary',
- 'decimal(38,18)',
- 'struct<a:bigint>',
- 'struct<a:bigint>',
- 'struct<a:bigint>',
- ]
- self.assertEqual(actual, expected)
-
- actual = list(df.first())
- expected = [
- True,
- 1,
- 'a',
- u"a",
- datetime.date(1970, 1, 1),
- datetime.datetime(1970, 1, 1, 0, 0),
- 1.0,
- [1.0],
- [1],
- Row(_1=1),
- {"a": 1},
- bytearray(b'\x00'),
- Decimal('1.000000000000000000'),
- Row(a=1),
- Row(a=1),
- Row(a=1),
- ]
- self.assertEqual(actual, expected)
-
- def test_infer_schema_not_enough_names(self):
- df = self.spark.createDataFrame([["a", "b"]], ["col1"])
- self.assertEqual(df.columns, ['col1', '_2'])
-
- def test_infer_schema_fails(self):
- with self.assertRaisesRegexp(TypeError, 'field a'):
- self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
- schema=["a", "b"], samplingRatio=0.99)
-
- def test_infer_nested_schema(self):
- NestedRow = Row("f1", "f2")
- nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
- NestedRow([2, 3], {"row2": 2.0})])
- df = self.spark.createDataFrame(nestedRdd1)
- self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
-
- nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
- NestedRow([[2, 3], [3, 4]], [2, 3])])
- df = self.spark.createDataFrame(nestedRdd2)
- self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
-
- from collections import namedtuple
- CustomRow = namedtuple('CustomRow', 'field1 field2')
- rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
- CustomRow(field1=2, field2="row2"),
- CustomRow(field1=3, field2="row3")])
- df = self.spark.createDataFrame(rdd)
- self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
-
- def test_create_dataframe_from_dict_respects_schema(self):
- df = self.spark.createDataFrame([{'a': 1}], ["b"])
- self.assertEqual(df.columns, ['b'])
-
- def test_create_dataframe_from_objects(self):
- data = [MyObject(1, "1"), MyObject(2, "2")]
- df = self.spark.createDataFrame(data)
- self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
- self.assertEqual(df.first(), Row(key=1, value="1"))
-
- def test_select_null_literal(self):
- df = self.spark.sql("select null as col")
- self.assertEqual(Row(col=None), df.first())
-
- def test_apply_schema(self):
- from datetime import date, datetime
- rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
- date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
- {"a": 1}, (2,), [1, 2, 3], None)])
- schema = StructType([
- StructField("byte1", ByteType(), False),
- StructField("byte2", ByteType(), False),
- StructField("short1", ShortType(), False),
- StructField("short2", ShortType(), False),
- StructField("int1", IntegerType(), False),
- StructField("float1", FloatType(), False),
- StructField("date1", DateType(), False),
- StructField("time1", TimestampType(), False),
- StructField("map1", MapType(StringType(), IntegerType(), False), False),
- StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
- StructField("list1", ArrayType(ByteType(), False), False),
- StructField("null1", DoubleType(), True)])
- df = self.spark.createDataFrame(rdd, schema)
- results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
- x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
- r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
- datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
- self.assertEqual(r, results.first())
-
- with self.tempView("table2"):
- df.createOrReplaceTempView("table2")
- r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
- "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
- "float1 + 1.5 as float1 FROM table2").first()
-
- self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
-
- def test_struct_in_map(self):
- d = [Row(m={Row(i=1): Row(s="")})]
- df = self.sc.parallelize(d).toDF()
- k, v = list(df.head().m.items())[0]
- self.assertEqual(1, k.i)
- self.assertEqual("", v.s)
-
- def test_convert_row_to_dict(self):
- row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
- self.assertEqual(1, row.asDict()['l'][0].a)
- df = self.sc.parallelize([row]).toDF()
-
- with self.tempView("test"):
- df.createOrReplaceTempView("test")
- row = self.spark.sql("select l, d from test").head()
- self.assertEqual(1, row.asDict()["l"][0].a)
- self.assertEqual(1.0, row.asDict()['d']['key'].c)
-
- def test_udt(self):
- from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
- from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
-
- def check_datatype(datatype):
- pickled = pickle.loads(pickle.dumps(datatype))
- assert datatype == pickled
- scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
- python_datatype = _parse_datatype_json_string(scala_datatype.json())
- assert datatype == python_datatype
-
- check_datatype(ExamplePointUDT())
- structtype_with_udt = StructType([StructField("label", DoubleType(), False),
- StructField("point", ExamplePointUDT(), False)])
- check_datatype(structtype_with_udt)
- p = ExamplePoint(1.0, 2.0)
- self.assertEqual(_infer_type(p), ExamplePointUDT())
- _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
- self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))
-
- check_datatype(PythonOnlyUDT())
- structtype_with_udt = StructType([StructField("label", DoubleType(), False),
- StructField("point", PythonOnlyUDT(), False)])
- check_datatype(structtype_with_udt)
- p = PythonOnlyPoint(1.0, 2.0)
- self.assertEqual(_infer_type(p), PythonOnlyUDT())
- _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
- self.assertRaises(
- ValueError,
- lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
-
- def test_simple_udt_in_df(self):
- schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
- df = self.spark.createDataFrame(
- [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
- schema=schema)
- df.collect()
-
- def test_nested_udt_in_df(self):
- schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
- df = self.spark.createDataFrame(
- [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
- schema=schema)
- df.collect()
-
- schema = StructType().add("key", LongType()).add("val",
- MapType(LongType(), PythonOnlyUDT()))
- df = self.spark.createDataFrame(
- [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
- schema=schema)
- df.collect()
-
- def test_complex_nested_udt_in_df(self):
- from pyspark.sql.functions import udf
-
- schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
- df = self.spark.createDataFrame(
- [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
- schema=schema)
- df.collect()
-
- gd = df.groupby("key").agg({"val": "collect_list"})
- gd.collect()
- udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
- gd.select(udf(*gd)).collect()
-
- def test_udt_with_none(self):
- df = self.spark.range(0, 10, 1, 1)
-
- def myudf(x):
- if x > 0:
- return PythonOnlyPoint(float(x), float(x))
-
- self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
- rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
- self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
-
- def test_nonparam_udf_with_aggregate(self):
- import pyspark.sql.functions as f
-
- df = self.spark.createDataFrame([(1, 2), (1, 2)])
- f_udf = f.udf(lambda: "const_str")
- rows = df.distinct().withColumn("a", f_udf()).collect()
- self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])
-
- def test_infer_schema_with_udt(self):
- from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
- row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df = self.spark.createDataFrame([row])
- schema = df.schema
- field = [f for f in schema.fields if f.name == "point"][0]
- self.assertEqual(type(field.dataType), ExamplePointUDT)
-
- with self.tempView("labeled_point"):
- df.createOrReplaceTempView("labeled_point")
- point = self.spark.sql("SELECT point FROM labeled_point").head().point
- self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
- row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
- df = self.spark.createDataFrame([row])
- schema = df.schema
- field = [f for f in schema.fields if f.name == "point"][0]
- self.assertEqual(type(field.dataType), PythonOnlyUDT)
-
- with self.tempView("labeled_point"):
- df.createOrReplaceTempView("labeled_point")
- point = self.spark.sql("SELECT point FROM labeled_point").head().point
- self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
-
- def test_apply_schema_with_udt(self):
- from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
- row = (1.0, ExamplePoint(1.0, 2.0))
- schema = StructType([StructField("label", DoubleType(), False),
- StructField("point", ExamplePointUDT(), False)])
- df = self.spark.createDataFrame([row], schema)
- point = df.head().point
- self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
- row = (1.0, PythonOnlyPoint(1.0, 2.0))
- schema = StructType([StructField("label", DoubleType(), False),
- StructField("point", PythonOnlyUDT(), False)])
- df = self.spark.createDataFrame([row], schema)
- point = df.head().point
- self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
-
- def test_udf_with_udt(self):
- from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
- row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df = self.spark.createDataFrame([row])
- self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
- udf = UserDefinedFunction(lambda p: p.y, DoubleType())
- self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
- udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
- self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
-
- row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
- df = self.spark.createDataFrame([row])
- self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
- udf = UserDefinedFunction(lambda p: p.y, DoubleType())
- self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
- udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
- self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
-
- def test_parquet_with_udt(self):
- from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
- row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df0 = self.spark.createDataFrame([row])
- output_dir = os.path.join(self.tempdir.name, "labeled_point")
- df0.write.parquet(output_dir)
- df1 = self.spark.read.parquet(output_dir)
- point = df1.head().point
- self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
- row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
- df0 = self.spark.createDataFrame([row])
- df0.write.parquet(output_dir, mode='overwrite')
- df1 = self.spark.read.parquet(output_dir)
- point = df1.head().point
- self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
-
- def test_union_with_udt(self):
- from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
- row1 = (1.0, ExamplePoint(1.0, 2.0))
- row2 = (2.0, ExamplePoint(3.0, 4.0))
- schema = StructType([StructField("label", DoubleType(), False),
- StructField("point", ExamplePointUDT(), False)])
- df1 = self.spark.createDataFrame([row1], schema)
- df2 = self.spark.createDataFrame([row2], schema)
-
- result = df1.union(df2).orderBy("label").collect()
- self.assertEqual(
- result,
- [
- Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
- Row(label=2.0, point=ExamplePoint(3.0, 4.0))
- ]
- )
-
- def test_cast_to_string_with_udt(self):
- from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
- from pyspark.sql.functions import col
- row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
- schema = StructType([StructField("point", ExamplePointUDT(), False),
- StructField("pypoint", PythonOnlyUDT(), False)])
- df = self.spark.createDataFrame([row], schema)
-
- result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
- self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))
-
- def test_column_operators(self):
- ci = self.df.key
- cs = self.df.value
- c = ci == cs
- self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
- rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
- self.assertTrue(all(isinstance(c, Column) for c in rcc))
- cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
- self.assertTrue(all(isinstance(c, Column) for c in cb))
- cbool = (ci & ci), (ci | ci), (~ci)
- self.assertTrue(all(isinstance(c, Column) for c in cbool))
- css = cs.contains('a'), cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(),\
- cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs)
- self.assertTrue(all(isinstance(c, Column) for c in css))
- self.assertTrue(isinstance(ci.cast(LongType()), Column))
- self.assertRaisesRegexp(ValueError,
- "Cannot apply 'in' operator against a column",
- lambda: 1 in cs)
-
- def test_column_getitem(self):
- from pyspark.sql.functions import col
-
- self.assertIsInstance(col("foo")[1:3], Column)
- self.assertIsInstance(col("foo")[0], Column)
- self.assertIsInstance(col("foo")["bar"], Column)
- self.assertRaises(ValueError, lambda: col("foo")[0:10:2])
-
- def test_column_select(self):
- df = self.df
- self.assertEqual(self.testData, df.select("*").collect())
- self.assertEqual(self.testData, df.select(df.key, df.value).collect())
- self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
-
- def test_freqItems(self):
- vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
- df = self.sc.parallelize(vals).toDF()
- items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
- self.assertTrue(1 in items[0])
- self.assertTrue(-2.0 in items[1])
-
- def test_aggregator(self):
- df = self.df
- g = df.groupBy()
- self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
- self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
-
- from pyspark.sql import functions
- self.assertEqual((0, u'99'),
- tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
- self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0])
- self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
-
- def test_first_last_ignorenulls(self):
- from pyspark.sql import functions
- df = self.spark.range(0, 100)
- df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
- df3 = df2.select(functions.first(df2.id, False).alias('a'),
- functions.first(df2.id, True).alias('b'),
- functions.last(df2.id, False).alias('c'),
- functions.last(df2.id, True).alias('d'))
- self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
-
- def test_approxQuantile(self):
- df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
- for f in ["a", u"a"]:
- aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
- self.assertTrue(isinstance(aq, list))
- self.assertEqual(len(aq), 3)
- self.assertTrue(all(isinstance(q, float) for q in aq))
- aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
- self.assertTrue(isinstance(aqs, list))
- self.assertEqual(len(aqs), 2)
- self.assertTrue(isinstance(aqs[0], list))
- self.assertEqual(len(aqs[0]), 3)
- self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
- self.assertTrue(isinstance(aqs[1], list))
- self.assertEqual(len(aqs[1]), 3)
- self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
- aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
- self.assertTrue(isinstance(aqt, list))
- self.assertEqual(len(aqt), 2)
- self.assertTrue(isinstance(aqt[0], list))
- self.assertEqual(len(aqt[0]), 3)
- self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
- self.assertTrue(isinstance(aqt[1], list))
- self.assertEqual(len(aqt[1]), 3)
- self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
- self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
- self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
- self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
-
- def test_corr(self):
- import math
- df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
- corr = df.stat.corr(u"a", "b")
- self.assertTrue(abs(corr - 0.95734012) < 1e-6)
-
- def test_sampleby(self):
- df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
- sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
- self.assertTrue(sampled.count() == 3)
-
- def test_cov(self):
- df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
- cov = df.stat.cov(u"a", "b")
- self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
-
- def test_crosstab(self):
- df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
- ct = df.stat.crosstab(u"a", "b").collect()
- ct = sorted(ct, key=lambda x: x[0])
- for i, row in enumerate(ct):
- self.assertEqual(row[0], str(i))
- self.assertTrue(row[1], 1)
- self.assertTrue(row[2], 1)
-
- def test_math_functions(self):
- df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
- from pyspark.sql import functions
- import math
-
- def get_values(l):
- return [j[0] for j in l]
-
- def assert_close(a, b):
- c = get_values(b)
- diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
- return sum(diff) == len(a)
- assert_close([math.cos(i) for i in range(10)],
- df.select(functions.cos(df.a)).collect())
- assert_close([math.cos(i) for i in range(10)],
- df.select(functions.cos("a")).collect())
- assert_close([math.sin(i) for i in range(10)],
- df.select(functions.sin(df.a)).collect())
- assert_close([math.sin(i) for i in range(10)],
- df.select(functions.sin(df['a'])).collect())
- assert_close([math.pow(i, 2 * i) for i in range(10)],
- df.select(functions.pow(df.a, df.b)).collect())
- assert_close([math.pow(i, 2) for i in range(10)],
- df.select(functions.pow(df.a, 2)).collect())
- assert_close([math.pow(i, 2) for i in range(10)],
- df.select(functions.pow(df.a, 2.0)).collect())
- assert_close([math.hypot(i, 2 * i) for i in range(10)],
- df.select(functions.hypot(df.a, df.b)).collect())
-
- def test_rand_functions(self):
- df = self.df
- from pyspark.sql import functions
- rnd = df.select('key', functions.rand()).collect()
- for row in rnd:
- assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
- rndn = df.select('key', functions.randn(5)).collect()
- for row in rndn:
- assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
-
- # If the specified seed is 0, we should use it.
- # https://issues.apache.org/jira/browse/SPARK-9691
- rnd1 = df.select('key', functions.rand(0)).collect()
- rnd2 = df.select('key', functions.rand(0)).collect()
- self.assertEqual(sorted(rnd1), sorted(rnd2))
-
- rndn1 = df.select('key', functions.randn(0)).collect()
- rndn2 = df.select('key', functions.randn(0)).collect()
- self.assertEqual(sorted(rndn1), sorted(rndn2))
-
- def test_string_functions(self):
- from pyspark.sql.functions import col, lit
- df = self.spark.createDataFrame([['nick']], schema=['name'])
- self.assertRaisesRegexp(
- TypeError,
- "must be the same type",
- lambda: df.select(col('name').substr(0, lit(1))))
- if sys.version_info.major == 2:
- self.assertRaises(
- TypeError,
- lambda: df.select(col('name').substr(long(0), long(1))))
-
- def test_array_contains_function(self):
- from pyspark.sql.functions import array_contains
-
- df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data'])
- actual = df.select(array_contains(df.data, "1").alias('b')).collect()
- self.assertEqual([Row(b=True), Row(b=False)], actual)
-
- def test_between_function(self):
- df = self.sc.parallelize([
- Row(a=1, b=2, c=3),
- Row(a=2, b=1, c=3),
- Row(a=4, b=1, c=4)]).toDF()
- self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
- df.filter(df.a.between(df.b, df.c)).collect())
-
- def test_struct_type(self):
- struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
- struct2 = StructType([StructField("f1", StringType(), True),
- StructField("f2", StringType(), True, None)])
- self.assertEqual(struct1.fieldNames(), struct2.names)
- self.assertEqual(struct1, struct2)
-
- struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
- struct2 = StructType([StructField("f1", StringType(), True)])
- self.assertNotEqual(struct1.fieldNames(), struct2.names)
- self.assertNotEqual(struct1, struct2)
-
- struct1 = (StructType().add(StructField("f1", StringType(), True))
- .add(StructField("f2", StringType(), True, None)))
- struct2 = StructType([StructField("f1", StringType(), True),
- StructField("f2", StringType(), True, None)])
- self.assertEqual(struct1.fieldNames(), struct2.names)
- self.assertEqual(struct1, struct2)
-
- struct1 = (StructType().add(StructField("f1", StringType(), True))
- .add(StructField("f2", StringType(), True, None)))
- struct2 = StructType([StructField("f1", StringType(), True)])
- self.assertNotEqual(struct1.fieldNames(), struct2.names)
- self.assertNotEqual(struct1, struct2)
-
- # Catch exception raised during improper construction
- self.assertRaises(ValueError, lambda: StructType().add("name"))
-
- struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
- for field in struct1:
- self.assertIsInstance(field, StructField)
-
- struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
- self.assertEqual(len(struct1), 2)
-
- struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
- self.assertIs(struct1["f1"], struct1.fields[0])
- self.assertIs(struct1[0], struct1.fields[0])
- self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
- self.assertRaises(KeyError, lambda: struct1["f9"])
- self.assertRaises(IndexError, lambda: struct1[9])
- self.assertRaises(TypeError, lambda: struct1[9.9])
-
- def test_struct_type_to_internal(self):
- # Verify when not needSerializeAnyField
- struct = StructType().add("b", StringType()).add("a", StringType())
- string_a = "value_a"
- string_b = "value_b"
- row = Row(a=string_a, b=string_b)
- tupleResult = struct.toInternal(row)
- # Reversed because of struct
- self.assertEqual(tupleResult, (string_b, string_a))
-
- # Verify when needSerializeAnyField
- struct1 = StructType().add("b", TimestampType()).add("a", TimestampType())
- timestamp_a = datetime.datetime(2018, 1, 1, 1, 1, 1)
- timestamp_b = datetime.datetime(2019, 1, 1, 1, 1, 1)
- row = Row(a=timestamp_a, b=timestamp_b)
- tupleResult = struct1.toInternal(row)
- # Reversed because of struct
- d = 1000000
- ts_b = tupleResult[0]
- new_timestamp_b = datetime.datetime.fromtimestamp(ts_b // d).replace(microsecond=ts_b % d)
- ts_a = tupleResult[1]
- new_timestamp_a = datetime.datetime.fromtimestamp(ts_a // d).replace(microsecond=ts_a % d)
- self.assertEqual(timestamp_a, new_timestamp_a)
- self.assertEqual(timestamp_b, new_timestamp_b)
-
- def test_parse_datatype_string(self):
- from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
- for k, t in _all_atomic_types.items():
- if t != NullType:
- self.assertEqual(t(), _parse_datatype_string(k))
- self.assertEqual(IntegerType(), _parse_datatype_string("int"))
- self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
- self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
- self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
- self.assertEqual(
- ArrayType(IntegerType()),
- _parse_datatype_string("array<int >"))
- self.assertEqual(
- MapType(IntegerType(), DoubleType()),
- _parse_datatype_string("map< int, double >"))
- self.assertEqual(
- StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
- _parse_datatype_string("struct<a:int, c:double >"))
- self.assertEqual(
- StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
- _parse_datatype_string("a:int, c:double"))
- self.assertEqual(
- StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
- _parse_datatype_string("a INT, c DOUBLE"))
-
- def test_metadata_null(self):
- schema = StructType([StructField("f1", StringType(), True, None),
- StructField("f2", StringType(), True, {'a': None})])
- rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
- self.spark.createDataFrame(rdd, schema)
-
- def test_save_and_load(self):
- df = self.df
- tmpPath = tempfile.mkdtemp()
- shutil.rmtree(tmpPath)
- df.write.json(tmpPath)
- actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- schema = StructType([StructField("value", StringType(), True)])
- actual = self.spark.read.json(tmpPath, schema)
- self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
- df.write.json(tmpPath, "overwrite")
- actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- df.write.save(format="json", mode="overwrite", path=tmpPath,
- noUse="this options will not be used in save.")
- actual = self.spark.read.load(format="json", path=tmpPath,
- noUse="this options will not be used in load.")
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
- "org.apache.spark.sql.parquet")
- self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
- actual = self.spark.read.load(path=tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
-
- csvpath = os.path.join(tempfile.mkdtemp(), 'data')
- df.write.option('quote', None).format('csv').save(csvpath)
-
- shutil.rmtree(tmpPath)
-
- def test_save_and_load_builder(self):
- df = self.df
- tmpPath = tempfile.mkdtemp()
- shutil.rmtree(tmpPath)
- df.write.json(tmpPath)
- actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- schema = StructType([StructField("value", StringType(), True)])
- actual = self.spark.read.json(tmpPath, schema)
- self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
- df.write.mode("overwrite").json(tmpPath)
- actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
- .option("noUse", "this option will not be used in save.")\
- .format("json").save(path=tmpPath)
- actual =\
- self.spark.read.format("json")\
- .load(path=tmpPath, noUse="this options will not be used in load.")
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
- "org.apache.spark.sql.parquet")
- self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
- actual = self.spark.read.load(path=tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
-
- shutil.rmtree(tmpPath)
-
- def test_stream_trigger(self):
- df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
-
- # Should take at least one arg
- try:
- df.writeStream.trigger()
- except ValueError:
- pass
-
- # Should not take multiple args
- try:
- df.writeStream.trigger(once=True, processingTime='5 seconds')
- except ValueError:
- pass
-
- # Should not take multiple args
- try:
- df.writeStream.trigger(processingTime='5 seconds', continuous='1 second')
- except ValueError:
- pass
-
- # Should take only keyword args
- try:
- df.writeStream.trigger('5 seconds')
- self.fail("Should have thrown an exception")
- except TypeError:
- pass
-
- def test_stream_read_options(self):
- schema = StructType([StructField("data", StringType(), False)])
- df = self.spark.readStream\
- .format('text')\
- .option('path', 'python/test_support/sql/streaming')\
- .schema(schema)\
- .load()
- self.assertTrue(df.isStreaming)
- self.assertEqual(df.schema.simpleString(), "struct<data:string>")
-
- def test_stream_read_options_overwrite(self):
- bad_schema = StructType([StructField("test", IntegerType(), False)])
- schema = StructType([StructField("data", StringType(), False)])
- df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \
- .schema(bad_schema)\
- .load(path='python/test_support/sql/streaming', schema=schema, format='text')
- self.assertTrue(df.isStreaming)
- self.assertEqual(df.schema.simpleString(), "struct<data:string>")
-
- def test_stream_save_options(self):
- df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \
- .withColumn('id', lit(1))
- for q in self.spark._wrapped.streams.active:
- q.stop()
- tmpPath = tempfile.mkdtemp()
- shutil.rmtree(tmpPath)
- self.assertTrue(df.isStreaming)
- out = os.path.join(tmpPath, 'out')
- chk = os.path.join(tmpPath, 'chk')
- q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \
- .format('parquet').partitionBy('id').outputMode('append').option('path', out).start()
- try:
- self.assertEqual(q.name, 'this_query')
- self.assertTrue(q.isActive)
- q.processAllAvailable()
- output_files = []
- for _, _, files in os.walk(out):
- output_files.extend([f for f in files if not f.startswith('.')])
- self.assertTrue(len(output_files) > 0)
- self.assertTrue(len(os.listdir(chk)) > 0)
- finally:
- q.stop()
- shutil.rmtree(tmpPath)
-
- def test_stream_save_options_overwrite(self):
- df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
- for q in self.spark._wrapped.streams.active:
- q.stop()
- tmpPath = tempfile.mkdtemp()
- shutil.rmtree(tmpPath)
- self.assertTrue(df.isStreaming)
- out = os.path.join(tmpPath, 'out')
- chk = os.path.join(tmpPath, 'chk')
- fake1 = os.path.join(tmpPath, 'fake1')
- fake2 = os.path.join(tmpPath, 'fake2')
- q = df.writeStream.option('checkpointLocation', fake1)\
- .format('memory').option('path', fake2) \
- .queryName('fake_query').outputMode('append') \
- .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
-
- try:
- self.assertEqual(q.name, 'this_query')
- self.assertTrue(q.isActive)
- q.processAllAvailable()
- output_files = []
- for _, _, files in os.walk(out):
- output_files.extend([f for f in files if not f.startswith('.')])
- self.assertTrue(len(output_files) > 0)
- self.assertTrue(len(os.listdir(chk)) > 0)
- self.assertFalse(os.path.isdir(fake1)) # should not have been created
- self.assertFalse(os.path.isdir(fake2)) # should not have been created
- finally:
- q.stop()
- shutil.rmtree(tmpPath)
-
- def test_stream_status_and_progress(self):
- df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
- for q in self.spark._wrapped.streams.active:
- q.stop()
- tmpPath = tempfile.mkdtemp()
- shutil.rmtree(tmpPath)
- self.assertTrue(df.isStreaming)
- out = os.path.join(tmpPath, 'out')
- chk = os.path.join(tmpPath, 'chk')
-
- def func(x):
- time.sleep(1)
- return x
-
- from pyspark.sql.functions import col, udf
- sleep_udf = udf(func)
-
- # Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there
- # were no updates.
- q = df.select(sleep_udf(col("value")).alias('value')).writeStream \
- .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
- try:
- # "lastProgress" will return None in most cases. However, as it may be flaky when
- # Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress"
- # may throw error with a high chance and make this test flaky, so we should still be
- # able to detect broken codes.
- q.lastProgress
-
- q.processAllAvailable()
- lastProgress = q.lastProgress
- recentProgress = q.recentProgress
- status = q.status
- self.assertEqual(lastProgress['name'], q.name)
- self.assertEqual(lastProgress['id'], q.id)
- self.assertTrue(any(p == lastProgress for p in recentProgress))
- self.assertTrue(
- "message" in status and
- "isDataAvailable" in status and
- "isTriggerActive" in status)
- finally:
- q.stop()
- shutil.rmtree(tmpPath)
-
- def test_stream_await_termination(self):
- df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
- for q in self.spark._wrapped.streams.active:
- q.stop()
- tmpPath = tempfile.mkdtemp()
- shutil.rmtree(tmpPath)
- self.assertTrue(df.isStreaming)
- out = os.path.join(tmpPath, 'out')
- chk = os.path.join(tmpPath, 'chk')
- q = df.writeStream\
- .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
- try:
- self.assertTrue(q.isActive)
- try:
- q.awaitTermination("hello")
- self.fail("Expected a value exception")
- except ValueError:
- pass
- now = time.time()
- # test should take at least 2 seconds
- res = q.awaitTermination(2.6)
- duration = time.time() - now
- self.assertTrue(duration >= 2)
- self.assertFalse(res)
- finally:
- q.stop()
- shutil.rmtree(tmpPath)
-
- def test_stream_exception(self):
- sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
- sq = sdf.writeStream.format('memory').queryName('query_explain').start()
- try:
- sq.processAllAvailable()
- self.assertEqual(sq.exception(), None)
- finally:
- sq.stop()
-
- from pyspark.sql.functions import col, udf
- from pyspark.sql.utils import StreamingQueryException
- bad_udf = udf(lambda x: 1 / 0)
- sq = sdf.select(bad_udf(col("value")))\
- .writeStream\
- .format('memory')\
- .queryName('this_query')\
- .start()
- try:
- # Process some data to fail the query
- sq.processAllAvailable()
- self.fail("bad udf should fail the query")
- except StreamingQueryException as e:
- # This is expected
- self.assertTrue("ZeroDivisionError" in e.desc)
- finally:
- sq.stop()
- self.assertTrue(type(sq.exception()) is StreamingQueryException)
- self.assertTrue("ZeroDivisionError" in sq.exception().desc)
-
- def test_query_manager_await_termination(self):
- df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
- for q in self.spark._wrapped.streams.active:
- q.stop()
- tmpPath = tempfile.mkdtemp()
- shutil.rmtree(tmpPath)
- self.assertTrue(df.isStreaming)
- out = os.path.join(tmpPath, 'out')
- chk = os.path.join(tmpPath, 'chk')
- q = df.writeStream\
- .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
- try:
- self.assertTrue(q.isActive)
- try:
- self.spark._wrapped.streams.awaitAnyTermination("hello")
- self.fail("Expected a value exception")
- except ValueError:
- pass
- now = time.time()
- # test should take at least 2 seconds
- res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
- duration = time.time() - now
- self.assertTrue(duration >= 2)
- self.assertFalse(res)
- finally:
- q.stop()
- shutil.rmtree(tmpPath)
-
- class ForeachWriterTester:
-
- def __init__(self, spark):
- self.spark = spark
-
- def write_open_event(self, partitionId, epochId):
- self._write_event(
- self.open_events_dir,
- {'partition': partitionId, 'epoch': epochId})
-
- def write_process_event(self, row):
- self._write_event(self.process_events_dir, {'value': 'text'})
-
- def write_close_event(self, error):
- self._write_event(self.close_events_dir, {'error': str(error)})
-
- def write_input_file(self):
- self._write_event(self.input_dir, "text")
-
- def open_events(self):
- return self._read_events(self.open_events_dir, 'partition INT, epoch INT')
-
- def process_events(self):
- return self._read_events(self.process_events_dir, 'value STRING')
-
- def close_events(self):
- return self._read_events(self.close_events_dir, 'error STRING')
-
- def run_streaming_query_on_writer(self, writer, num_files):
- self._reset()
- try:
- sdf = self.spark.readStream.format('text').load(self.input_dir)
- sq = sdf.writeStream.foreach(writer).start()
- for i in range(num_files):
- self.write_input_file()
- sq.processAllAvailable()
- finally:
- self.stop_all()
-
- def assert_invalid_writer(self, writer, msg=None):
- self._reset()
- try:
- sdf = self.spark.readStream.format('text').load(self.input_dir)
- sq = sdf.writeStream.foreach(writer).start()
- self.write_input_file()
- sq.processAllAvailable()
- self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected
- except Exception as e:
- if msg:
- assert msg in str(e), "%s not in %s" % (msg, str(e))
-
- finally:
- self.stop_all()
-
- def stop_all(self):
- for q in self.spark._wrapped.streams.active:
- q.stop()
-
- def _reset(self):
- self.input_dir = tempfile.mkdtemp()
- self.open_events_dir = tempfile.mkdtemp()
- self.process_events_dir = tempfile.mkdtemp()
- self.close_events_dir = tempfile.mkdtemp()
-
- def _read_events(self, dir, json):
- rows = self.spark.read.schema(json).json(dir).collect()
- dicts = [row.asDict() for row in rows]
- return dicts
-
- def _write_event(self, dir, event):
- import uuid
- with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f:
- f.write("%s\n" % str(event))
-
- def __getstate__(self):
- return (self.open_events_dir, self.process_events_dir, self.close_events_dir)
-
- def __setstate__(self, state):
- self.open_events_dir, self.process_events_dir, self.close_events_dir = state
-
- # Those foreach tests are failed in Python 3.6 and macOS High Sierra by defined rules
- # at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html
- # To work around this, OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES.
- def test_streaming_foreach_with_simple_function(self):
- tester = self.ForeachWriterTester(self.spark)
-
- def foreach_func(row):
- tester.write_process_event(row)
-
- tester.run_streaming_query_on_writer(foreach_func, 2)
- self.assertEqual(len(tester.process_events()), 2)
-
- def test_streaming_foreach_with_basic_open_process_close(self):
- tester = self.ForeachWriterTester(self.spark)
-
- class ForeachWriter:
- def open(self, partitionId, epochId):
- tester.write_open_event(partitionId, epochId)
- return True
-
- def process(self, row):
- tester.write_process_event(row)
-
- def close(self, error):
- tester.write_close_event(error)
-
- tester.run_streaming_query_on_writer(ForeachWriter(), 2)
-
- open_events = tester.open_events()
- self.assertEqual(len(open_events), 2)
- self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})
-
- self.assertEqual(len(tester.process_events()), 2)
-
- close_events = tester.close_events()
- self.assertEqual(len(close_events), 2)
- self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
-
- def test_streaming_foreach_with_open_returning_false(self):
- tester = self.ForeachWriterTester(self.spark)
-
- class ForeachWriter:
- def open(self, partition_id, epoch_id):
- tester.write_open_event(partition_id, epoch_id)
- return False
-
- def process(self, row):
- tester.write_process_event(row)
-
- def close(self, error):
- tester.write_close_event(error)
-
- tester.run_streaming_query_on_writer(ForeachWriter(), 2)
-
- self.assertEqual(len(tester.open_events()), 2)
-
- self.assertEqual(len(tester.process_events()), 0) # no row was processed
-
- close_events = tester.close_events()
- self.assertEqual(len(close_events), 2)
- self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
-
- def test_streaming_foreach_without_open_method(self):
- tester = self.ForeachWriterTester(self.spark)
-
- class ForeachWriter:
- def process(self, row):
- tester.write_process_event(row)
-
- def close(self, error):
- tester.write_close_event(error)
-
- tester.run_streaming_query_on_writer(ForeachWriter(), 2)
- self.assertEqual(len(tester.open_events()), 0) # no open events
- self.assertEqual(len(tester.process_events()), 2)
- self.assertEqual(len(tester.close_events()), 2)
-
- def test_streaming_foreach_without_close_method(self):
- tester = self.ForeachWriterTester(self.spark)
-
- class ForeachWriter:
- def open(self, partition_id, epoch_id):
- tester.write_open_event(partition_id, epoch_id)
- return True
-
- def process(self, row):
- tester.write_process_event(row)
-
- tester.run_streaming_query_on_writer(ForeachWriter(), 2)
- self.assertEqual(len(tester.open_events()), 2) # no open events
- self.assertEqual(len(tester.process_events()), 2)
- self.assertEqual(len(tester.close_events()), 0)
-
- def test_streaming_foreach_without_open_and_close_methods(self):
- tester = self.ForeachWriterTester(self.spark)
-
- class ForeachWriter:
- def process(self, row):
- tester.write_process_event(row)
-
- tester.run_streaming_query_on_writer(ForeachWriter(), 2)
- self.assertEqual(len(tester.open_events()), 0) # no open events
- self.assertEqual(len(tester.process_events()), 2)
- self.assertEqual(len(tester.close_events()), 0)
-
- def test_streaming_foreach_with_process_throwing_error(self):
- from pyspark.sql.utils import StreamingQueryException
-
- tester = self.ForeachWriterTester(self.spark)
-
- class ForeachWriter:
- def process(self, row):
- raise Exception("test error")
-
- def close(self, error):
- tester.write_close_event(error)
-
- try:
- tester.run_streaming_query_on_writer(ForeachWriter(), 1)
- self.fail("bad writer did not fail the query") # this is not expected
- except StreamingQueryException as e:
- # TODO: Verify whether original error message is inside the exception
- pass
-
- self.assertEqual(len(tester.process_events()), 0) # no row was processed
- close_events = tester.close_events()
- self.assertEqual(len(close_events), 1)
- # TODO: Verify whether original error message is inside the exception
-
- def test_streaming_foreach_with_invalid_writers(self):
-
- tester = self.ForeachWriterTester(self.spark)
-
- def func_with_iterator_input(iter):
- for x in iter:
- print(x)
-
- tester.assert_invalid_writer(func_with_iterator_input)
-
- class WriterWithoutProcess:
- def open(self, partition):
- pass
-
- tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'")
-
- class WriterWithNonCallableProcess():
- process = True
-
- tester.assert_invalid_writer(WriterWithNonCallableProcess(),
- "'process' in provided object is not callable")
-
- class WriterWithNoParamProcess():
- def process(self):
- pass
-
- tester.assert_invalid_writer(WriterWithNoParamProcess())
-
- # Abstract class for tests below
- class WithProcess():
- def process(self, row):
- pass
-
- class WriterWithNonCallableOpen(WithProcess):
- open = True
-
- tester.assert_invalid_writer(WriterWithNonCallableOpen(),
- "'open' in provided object is not callable")
-
- class WriterWithNoParamOpen(WithProcess):
- def open(self):
- pass
-
- tester.assert_invalid_writer(WriterWithNoParamOpen())
-
- class WriterWithNonCallableClose(WithProcess):
- close = True
-
- tester.assert_invalid_writer(WriterWithNonCallableClose(),
- "'close' in provided object is not callable")
-
- def test_streaming_foreachBatch(self):
- q = None
- collected = dict()
-
- def collectBatch(batch_df, batch_id):
- collected[batch_id] = batch_df.collect()
-
- try:
- df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
- q = df.writeStream.foreachBatch(collectBatch).start()
- q.processAllAvailable()
- self.assertTrue(0 in collected)
- self.assertTrue(len(collected[0]), 2)
- finally:
- if q:
- q.stop()
-
- def test_streaming_foreachBatch_propagates_python_errors(self):
- from pyspark.sql.utils import StreamingQueryException
-
- q = None
-
- def collectBatch(df, id):
- raise Exception("this should fail the query")
-
- try:
- df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
- q = df.writeStream.foreachBatch(collectBatch).start()
- q.processAllAvailable()
- self.fail("Expected a failure")
- except StreamingQueryException as e:
- self.assertTrue("this should fail" in str(e))
- finally:
- if q:
- q.stop()
-
- def test_help_command(self):
- # Regression test for SPARK-5464
- rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
- df = self.spark.read.json(rdd)
- # render_doc() reproduces the help() exception without printing output
- pydoc.render_doc(df)
- pydoc.render_doc(df.foo)
- pydoc.render_doc(df.take(1))
-
- def test_access_column(self):
- df = self.df
- self.assertTrue(isinstance(df.key, Column))
- self.assertTrue(isinstance(df['key'], Column))
- self.assertTrue(isinstance(df[0], Column))
- self.assertRaises(IndexError, lambda: df[2])
- self.assertRaises(AnalysisException, lambda: df["bad_key"])
- self.assertRaises(TypeError, lambda: df[{}])
-
- def test_column_name_with_non_ascii(self):
- if sys.version >= '3':
- columnName = "数量"
- self.assertTrue(isinstance(columnName, str))
- else:
- columnName = unicode("数量", "utf-8")
- self.assertTrue(isinstance(columnName, unicode))
- schema = StructType([StructField(columnName, LongType(), True)])
- df = self.spark.createDataFrame([(1,)], schema)
- self.assertEqual(schema, df.schema)
- self.assertEqual("DataFrame[数量: bigint]", str(df))
- self.assertEqual([("数量", 'bigint')], df.dtypes)
- self.assertEqual(1, df.select("数量").first()[0])
- self.assertEqual(1, df.select(df["数量"]).first()[0])
-
- def test_access_nested_types(self):
- df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
- self.assertEqual(1, df.select(df.l[0]).first()[0])
- self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
- self.assertEqual(1, df.select(df.r.a).first()[0])
- self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
- self.assertEqual("v", df.select(df.d["k"]).first()[0])
- self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
-
- def test_field_accessor(self):
- df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
- self.assertEqual(1, df.select(df.l[0]).first()[0])
- self.assertEqual(1, df.select(df.r["a"]).first()[0])
- self.assertEqual(1, df.select(df["r.a"]).first()[0])
- self.assertEqual("b", df.select(df.r["b"]).first()[0])
- self.assertEqual("b", df.select(df["r.b"]).first()[0])
- self.assertEqual("v", df.select(df.d["k"]).first()[0])
-
- def test_infer_long_type(self):
- longrow = [Row(f1='a', f2=100000000000000)]
- df = self.sc.parallelize(longrow).toDF()
- self.assertEqual(df.schema.fields[1].dataType, LongType())
-
- # this saving as Parquet caused issues as well.
- output_dir = os.path.join(self.tempdir.name, "infer_long_type")
- df.write.parquet(output_dir)
- df1 = self.spark.read.parquet(output_dir)
- self.assertEqual('a', df1.first().f1)
- self.assertEqual(100000000000000, df1.first().f2)
-
- self.assertEqual(_infer_type(1), LongType())
- self.assertEqual(_infer_type(2**10), LongType())
- self.assertEqual(_infer_type(2**20), LongType())
- self.assertEqual(_infer_type(2**31 - 1), LongType())
- self.assertEqual(_infer_type(2**31), LongType())
- self.assertEqual(_infer_type(2**61), LongType())
- self.assertEqual(_infer_type(2**71), LongType())
-
- def test_merge_type(self):
- self.assertEqual(_merge_type(LongType(), NullType()), LongType())
- self.assertEqual(_merge_type(NullType(), LongType()), LongType())
-
- self.assertEqual(_merge_type(LongType(), LongType()), LongType())
-
- self.assertEqual(_merge_type(
- ArrayType(LongType()),
- ArrayType(LongType())
- ), ArrayType(LongType()))
- with self.assertRaisesRegexp(TypeError, 'element in array'):
- _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
-
- self.assertEqual(_merge_type(
- MapType(StringType(), LongType()),
- MapType(StringType(), LongType())
- ), MapType(StringType(), LongType()))
- with self.assertRaisesRegexp(TypeError, 'key of map'):
- _merge_type(
- MapType(StringType(), LongType()),
- MapType(DoubleType(), LongType()))
- with self.assertRaisesRegexp(TypeError, 'value of map'):
- _merge_type(
- MapType(StringType(), LongType()),
- MapType(StringType(), DoubleType()))
-
- self.assertEqual(_merge_type(
- StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
- StructType([StructField("f1", LongType()), StructField("f2", StringType())])
- ), StructType([StructField("f1", LongType()), StructField("f2", StringType())]))
- with self.assertRaisesRegexp(TypeError, 'field f1'):
- _merge_type(
- StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
- StructType([StructField("f1", DoubleType()), StructField("f2", StringType())]))
-
- self.assertEqual(_merge_type(
- StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
- StructType([StructField("f1", StructType([StructField("f2", LongType())]))])
- ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))]))
- with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
- _merge_type(
- StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
- StructType([StructField("f1", StructType([StructField("f2", StringType())]))]))
-
- self.assertEqual(_merge_type(
- StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]),
- StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])
- ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]))
- with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
- _merge_type(
- StructType([
- StructField("f1", ArrayType(LongType())),
- StructField("f2", StringType())]),
- StructType([
- StructField("f1", ArrayType(DoubleType())),
- StructField("f2", StringType())]))
-
- self.assertEqual(_merge_type(
- StructType([
- StructField("f1", MapType(StringType(), LongType())),
- StructField("f2", StringType())]),
- StructType([
- StructField("f1", MapType(StringType(), LongType())),
- StructField("f2", StringType())])
- ), StructType([
- StructField("f1", MapType(StringType(), LongType())),
- StructField("f2", StringType())]))
- with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
- _merge_type(
- StructType([
- StructField("f1", MapType(StringType(), LongType())),
- StructField("f2", StringType())]),
- StructType([
- StructField("f1", MapType(StringType(), DoubleType())),
- StructField("f2", StringType())]))
-
- self.assertEqual(_merge_type(
- StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
- StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])
- ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]))
- with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
- _merge_type(
- StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
- StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))])
- )
-
- def test_filter_with_datetime(self):
- time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
- date = time.date()
- row = Row(date=date, time=time)
- df = self.spark.createDataFrame([row])
- self.assertEqual(1, df.filter(df.date == date).count())
- self.assertEqual(1, df.filter(df.time == time).count())
- self.assertEqual(0, df.filter(df.date > date).count())
- self.assertEqual(0, df.filter(df.time > time).count())
-
- def test_filter_with_datetime_timezone(self):
- dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
- dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
- row = Row(date=dt1)
- df = self.spark.createDataFrame([row])
- self.assertEqual(0, df.filter(df.date == dt2).count())
- self.assertEqual(1, df.filter(df.date > dt2).count())
- self.assertEqual(0, df.filter(df.date < dt2).count())
-
- def test_time_with_timezone(self):
- day = datetime.date.today()
- now = datetime.datetime.now()
- ts = time.mktime(now.timetuple())
- # class in __main__ is not serializable
- from pyspark.sql.tests import UTCOffsetTimezone
- utc = UTCOffsetTimezone()
- utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds
- # add microseconds to utcnow (keeping year,month,day,hour,minute,second)
- utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
- df = self.spark.createDataFrame([(day, now, utcnow)])
- day1, now1, utcnow1 = df.first()
- self.assertEqual(day1, day)
- self.assertEqual(now, now1)
- self.assertEqual(now, utcnow1)
-
- # regression test for SPARK-19561
- def test_datetime_at_epoch(self):
- epoch = datetime.datetime.fromtimestamp(0)
- df = self.spark.createDataFrame([Row(date=epoch)])
- first = df.select('date', lit(epoch).alias('lit_date')).first()
- self.assertEqual(first['date'], epoch)
- self.assertEqual(first['lit_date'], epoch)
-
- def test_dayofweek(self):
- from pyspark.sql.functions import dayofweek
- dt = datetime.datetime(2017, 11, 6)
- df = self.spark.createDataFrame([Row(date=dt)])
- row = df.select(dayofweek(df.date)).first()
- self.assertEqual(row[0], 2)
-
- def test_decimal(self):
- from decimal import Decimal
- schema = StructType([StructField("decimal", DecimalType(10, 5))])
- df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema)
- row = df.select(df.decimal + 1).first()
- self.assertEqual(row[0], Decimal("4.14159"))
- tmpPath = tempfile.mkdtemp()
- shutil.rmtree(tmpPath)
- df.write.parquet(tmpPath)
- df2 = self.spark.read.parquet(tmpPath)
- row = df2.first()
- self.assertEqual(row[0], Decimal("3.14159"))
-
- def test_dropna(self):
- schema = StructType([
- StructField("name", StringType(), True),
- StructField("age", IntegerType(), True),
- StructField("height", DoubleType(), True)])
-
- # shouldn't drop a non-null row
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', 50, 80.1)], schema).dropna().count(),
- 1)
-
- # dropping rows with a single null value
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', None, 80.1)], schema).dropna().count(),
- 0)
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
- 0)
-
- # if how = 'all', only drop rows if all values are null
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
- 1)
- self.assertEqual(self.spark.createDataFrame(
- [(None, None, None)], schema).dropna(how='all').count(),
- 0)
-
- # how and subset
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
- 1)
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
- 0)
-
- # threshold
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
- 1)
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', None, None)], schema).dropna(thresh=2).count(),
- 0)
-
- # threshold and subset
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
- 1)
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
- 0)
-
- # thresh should take precedence over how
- self.assertEqual(self.spark.createDataFrame(
- [(u'Alice', 50, None)], schema).dropna(
- how='any', thresh=2, subset=['name', 'age']).count(),
- 1)
-
- def test_fillna(self):
- schema = StructType([
- StructField("name", StringType(), True),
- StructField("age", IntegerType(), True),
- StructField("height", DoubleType(), True),
- StructField("spy", BooleanType(), True)])
-
- # fillna shouldn't change non-null values
- row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first()
- self.assertEqual(row.age, 10)
-
- # fillna with int
- row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first()
- self.assertEqual(row.age, 50)
- self.assertEqual(row.height, 50.0)
-
- # fillna with double
- row = self.spark.createDataFrame(
- [(u'Alice', None, None, None)], schema).fillna(50.1).first()
- self.assertEqual(row.age, 50)
- self.assertEqual(row.height, 50.1)
-
- # fillna with bool
- row = self.spark.createDataFrame(
- [(u'Alice', None, None, None)], schema).fillna(True).first()
- self.assertEqual(row.age, None)
- self.assertEqual(row.spy, True)
-
- # fillna with string
- row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first()
- self.assertEqual(row.name, u"hello")
- self.assertEqual(row.age, None)
-
- # fillna with subset specified for numeric cols
- row = self.spark.createDataFrame(
- [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
- self.assertEqual(row.name, None)
- self.assertEqual(row.age, 50)
- self.assertEqual(row.height, None)
- self.assertEqual(row.spy, None)
-
- # fillna with subset specified for string cols
- row = self.spark.createDataFrame(
- [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
- self.assertEqual(row.name, "haha")
- self.assertEqual(row.age, None)
- self.assertEqual(row.height, None)
- self.assertEqual(row.spy, None)
-
- # fillna with subset specified for bool cols
- row = self.spark.createDataFrame(
- [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first()
- self.assertEqual(row.name, None)
- self.assertEqual(row.age, None)
- self.assertEqual(row.height, None)
- self.assertEqual(row.spy, True)
-
- # fillna with dictionary for boolean types
- row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first()
- self.assertEqual(row.a, True)
-
- def test_bitwise_operations(self):
- from pyspark.sql import functions
- row = Row(a=170, b=75)
- df = self.spark.createDataFrame([row])
- result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
- self.assertEqual(170 & 75, result['(a & b)'])
- result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
- self.assertEqual(170 | 75, result['(a | b)'])
- result = df.select(df.a.bitwiseXOR(df.b)).collect()[0].asDict()
- self.assertEqual(170 ^ 75, result['(a ^ b)'])
- result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict()
- self.assertEqual(~75, result['~b'])
-
- def test_expr(self):
- from pyspark.sql import functions
- row = Row(a="length string", b=75)
- df = self.spark.createDataFrame([row])
- result = df.select(functions.expr("length(a)")).collect()[0].asDict()
- self.assertEqual(13, result["length(a)"])
-
- def test_repartitionByRange_dataframe(self):
- schema = StructType([
- StructField("name", StringType(), True),
- StructField("age", IntegerType(), True),
- StructField("height", DoubleType(), True)])
-
- df1 = self.spark.createDataFrame(
- [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema)
- df2 = self.spark.createDataFrame(
- [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema)
-
- # test repartitionByRange(numPartitions, *cols)
- df3 = df1.repartitionByRange(2, "name", "age")
- self.assertEqual(df3.rdd.getNumPartitions(), 2)
- self.assertEqual(df3.rdd.first(), df2.rdd.first())
- self.assertEqual(df3.rdd.take(3), df2.rdd.take(3))
-
- # test repartitionByRange(numPartitions, *cols)
- df4 = df1.repartitionByRange(3, "name", "age")
- self.assertEqual(df4.rdd.getNumPartitions(), 3)
- self.assertEqual(df4.rdd.first(), df2.rdd.first())
- self.assertEqual(df4.rdd.take(3), df2.rdd.take(3))
-
- # test repartitionByRange(*cols)
- df5 = df1.repartitionByRange("name", "age")
- self.assertEqual(df5.rdd.first(), df2.rdd.first())
- self.assertEqual(df5.rdd.take(3), df2.rdd.take(3))
-
- def test_replace(self):
- schema = StructType([
- StructField("name", StringType(), True),
- StructField("age", IntegerType(), True),
- StructField("height", DoubleType(), True)])
-
- # replace with int
- row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
- self.assertEqual(row.age, 20)
- self.assertEqual(row.height, 20.0)
-
- # replace with double
- row = self.spark.createDataFrame(
- [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
- self.assertEqual(row.age, 82)
- self.assertEqual(row.height, 82.1)
-
- # replace with string
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
- self.assertEqual(row.name, u"Ann")
- self.assertEqual(row.age, 10)
-
- # replace with subset specified by a string of a column name w/ actual change
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
- self.assertEqual(row.age, 20)
-
- # replace with subset specified by a string of a column name w/o actual change
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
- self.assertEqual(row.age, 10)
-
- # replace with subset specified with one column replaced, another column not in subset
- # stays unchanged.
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
- self.assertEqual(row.name, u'Alice')
- self.assertEqual(row.age, 20)
- self.assertEqual(row.height, 10.0)
-
- # replace with subset specified but no column will be replaced
- row = self.spark.createDataFrame(
- [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
- self.assertEqual(row.name, u'Alice')
- self.assertEqual(row.age, 10)
- self.assertEqual(row.height, None)
-
- # replace with lists
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
- self.assertTupleEqual(row, (u'Ann', 10, 80.1))
-
- # replace with dict
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
- self.assertTupleEqual(row, (u'Alice', 11, 80.1))
-
- # test backward compatibility with dummy value
- dummy_value = 1
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first()
- self.assertTupleEqual(row, (u'Bob', 10, 80.1))
-
- # test dict with mixed numerics
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first()
- self.assertTupleEqual(row, (u'Alice', -10, 90.5))
-
- # replace with tuples
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
- self.assertTupleEqual(row, (u'Bob', 10, 80.1))
-
- # replace multiple columns
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
- self.assertTupleEqual(row, (u'Alice', 20, 90.0))
-
- # test for mixed numerics
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
- self.assertTupleEqual(row, (u'Alice', 20, 90.5))
-
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
- self.assertTupleEqual(row, (u'Alice', 20, 90.5))
-
- # replace with boolean
- row = (self
- .spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
- .selectExpr("name = 'Bob'", 'age <= 15')
- .replace(False, True).first())
- self.assertTupleEqual(row, (True, True))
-
- # replace string with None and then drop None rows
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna()
- self.assertEqual(row.count(), 0)
-
- # replace with number and None
- row = self.spark.createDataFrame(
- [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first()
- self.assertTupleEqual(row, (u'Alice', 20, None))
-
- # should fail if subset is not list, tuple or None
- with self.assertRaises(ValueError):
- self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()
-
- # should fail if to_replace and value have different length
- with self.assertRaises(ValueError):
- self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()
-
-        # should fail when an unexpected type is received
- with self.assertRaises(ValueError):
- from datetime import datetime
- self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()
-
- # should fail if provided mixed type replacements
- with self.assertRaises(ValueError):
- self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()
-
- with self.assertRaises(ValueError):
- self.spark.createDataFrame(
- [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()
-
- with self.assertRaisesRegexp(
- TypeError,
- 'value argument is required when to_replace is not a dictionary.'):
- self.spark.createDataFrame(
- [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()
-
- def test_capture_analysis_exception(self):
- self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
- self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
-
- def test_capture_parse_exception(self):
- self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
-
- def test_capture_illegalargument_exception(self):
- self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
- lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
- df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
- self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
- lambda: df.select(sha2(df.a, 1024)).collect())
- try:
- df.select(sha2(df.a, 1024)).collect()
- except IllegalArgumentException as e:
- self.assertRegexpMatches(e.desc, "1024 is not in the permitted values")
- self.assertRegexpMatches(e.stackTrace,
- "org.apache.spark.sql.functions")
-
- def test_with_column_with_existing_name(self):
- keys = self.df.withColumn("key", self.df.key).select("key").collect()
- self.assertEqual([r.key for r in keys], list(range(100)))
-
- # regression test for SPARK-10417
- def test_column_iterator(self):
-
- def foo():
- for x in self.df.key:
- break
-
- self.assertRaises(TypeError, foo)
-
- # add test for SPARK-10577 (test broadcast join hint)
- def test_functions_broadcast(self):
- from pyspark.sql.functions import broadcast
-
- df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
- df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
-
- # equijoin - should be converted into broadcast join
- plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
- self.assertEqual(1, plan1.toString().count("BroadcastHashJoin"))
-
- # no join key -- should not be a broadcast join
- plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan()
- self.assertEqual(0, plan2.toString().count("BroadcastHashJoin"))
-
- # planner should not crash without a join
- broadcast(df1)._jdf.queryExecution().executedPlan()
-
- def test_generic_hints(self):
- from pyspark.sql import DataFrame
-
- df1 = self.spark.range(10e10).toDF("id")
- df2 = self.spark.range(10e10).toDF("id")
-
- self.assertIsInstance(df1.hint("broadcast"), DataFrame)
- self.assertIsInstance(df1.hint("broadcast", []), DataFrame)
-
- # Dummy rules
- self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
- self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)
-
- plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
- self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
-
- def test_sample(self):
- self.assertRaisesRegexp(
- TypeError,
- "should be a bool, float and number",
- lambda: self.spark.range(1).sample())
-
- self.assertRaises(
- TypeError,
- lambda: self.spark.range(1).sample("a"))
-
- self.assertRaises(
- TypeError,
- lambda: self.spark.range(1).sample(seed="abc"))
-
- self.assertRaises(
- IllegalArgumentException,
- lambda: self.spark.range(1).sample(-1.0))
-
- def test_toDF_with_schema_string(self):
- data = [Row(key=i, value=str(i)) for i in range(100)]
- rdd = self.sc.parallelize(data, 5)
-
- df = rdd.toDF("key: int, value: string")
-        self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
- self.assertEqual(df.collect(), data)
-
- # different but compatible field types can be used.
- df = rdd.toDF("key: string, value: string")
-        self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
- self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])
-
- # field names can differ.
- df = rdd.toDF(" a: int, b: string ")
-        self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
- self.assertEqual(df.collect(), data)
-
- # number of fields must match.
- self.assertRaisesRegexp(Exception, "Length of object",
- lambda: rdd.toDF("key: int").collect())
-
- # field types mismatch will cause exception at runtime.
- self.assertRaisesRegexp(Exception, "FloatType can not accept",
- lambda: rdd.toDF("key: float, value: string").collect())
-
- # flat schema values will be wrapped into row.
- df = rdd.map(lambda row: row.key).toDF("int")
-        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
- self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
-
- # users can use DataType directly instead of data type string.
- df = rdd.map(lambda row: row.key).toDF(IntegerType())
-        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
- self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
-
- def test_join_without_on(self):
- df1 = self.spark.range(1).toDF("a")
- df2 = self.spark.range(1).toDF("b")
-
- with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
- self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
-
- with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
- actual = df1.join(df2, how="inner").collect()
- expected = [Row(a=0, b=0)]
- self.assertEqual(actual, expected)
-
- # Regression test for invalid join methods when on is None, Spark-14761
- def test_invalid_join_method(self):
- df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
- df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"])
- self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type"))
-
- # Cartesian products require cross join syntax
- def test_require_cross(self):
- from pyspark.sql.functions import broadcast
-
- df1 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
- df2 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
-
- # joins without conditions require cross join syntax
- self.assertRaises(AnalysisException, lambda: df1.join(df2).collect())
-
- # works with crossJoin
- self.assertEqual(1, df1.crossJoin(df2).count())
-
- def test_conf(self):
- spark = self.spark
- spark.conf.set("bogo", "sipeo")
- self.assertEqual(spark.conf.get("bogo"), "sipeo")
- spark.conf.set("bogo", "ta")
- self.assertEqual(spark.conf.get("bogo"), "ta")
- self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
- self.assertEqual(spark.conf.get("not.set", "ta"), "ta")
- self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set"))
- spark.conf.unset("bogo")
- self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
-
- self.assertEqual(spark.conf.get("hyukjin", None), None)
-
- # This returns 'STATIC' because it's the default value of
- # 'spark.sql.sources.partitionOverwriteMode', and `defaultValue` in
- # `spark.conf.get` is unset.
- self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode"), "STATIC")
-
- # This returns None because 'spark.sql.sources.partitionOverwriteMode' is unset, but
- # `defaultValue` in `spark.conf.get` is set to None.
- self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None)
-
- def test_current_database(self):
- spark = self.spark
- with self.database("some_db"):
- self.assertEquals(spark.catalog.currentDatabase(), "default")
- spark.sql("CREATE DATABASE some_db")
- spark.catalog.setCurrentDatabase("some_db")
- self.assertEquals(spark.catalog.currentDatabase(), "some_db")
- self.assertRaisesRegexp(
- AnalysisException,
- "does_not_exist",
- lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
-
- def test_list_databases(self):
- spark = self.spark
- with self.database("some_db"):
- databases = [db.name for db in spark.catalog.listDatabases()]
- self.assertEquals(databases, ["default"])
- spark.sql("CREATE DATABASE some_db")
- databases = [db.name for db in spark.catalog.listDatabases()]
- self.assertEquals(sorted(databases), ["default", "some_db"])
-
- def test_list_tables(self):
- from pyspark.sql.catalog import Table
- spark = self.spark
- with self.database("some_db"):
- spark.sql("CREATE DATABASE some_db")
- with self.table("tab1", "some_db.tab2"):
- with self.tempView("temp_tab"):
- self.assertEquals(spark.catalog.listTables(), [])
- self.assertEquals(spark.catalog.listTables("some_db"), [])
- spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
- spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
- spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
- tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
- tablesDefault = \
- sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
- tablesSomeDb = \
- sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
- self.assertEquals(tables, tablesDefault)
- self.assertEquals(len(tables), 2)
- self.assertEquals(len(tablesSomeDb), 2)
- self.assertEquals(tables[0], Table(
- name="tab1",
- database="default",
- description=None,
- tableType="MANAGED",
- isTemporary=False))
- self.assertEquals(tables[1], Table(
- name="temp_tab",
- database=None,
- description=None,
- tableType="TEMPORARY",
- isTemporary=True))
- self.assertEquals(tablesSomeDb[0], Table(
- name="tab2",
- database="some_db",
- description=None,
- tableType="MANAGED",
- isTemporary=False))
- self.assertEquals(tablesSomeDb[1], Table(
- name="temp_tab",
- database=None,
- description=None,
- tableType="TEMPORARY",
- isTemporary=True))
- self.assertRaisesRegexp(
- AnalysisException,
- "does_not_exist",
- lambda: spark.catalog.listTables("does_not_exist"))
-
- def test_list_functions(self):
- from pyspark.sql.catalog import Function
- spark = self.spark
- with self.database("some_db"):
- spark.sql("CREATE DATABASE some_db")
- functions = dict((f.name, f) for f in spark.catalog.listFunctions())
- functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
- self.assertTrue(len(functions) > 200)
- self.assertTrue("+" in functions)
- self.assertTrue("like" in functions)
- self.assertTrue("month" in functions)
- self.assertTrue("to_date" in functions)
- self.assertTrue("to_timestamp" in functions)
- self.assertTrue("to_unix_timestamp" in functions)
- self.assertTrue("current_database" in functions)
- self.assertEquals(functions["+"], Function(
- name="+",
- description=None,
- className="org.apache.spark.sql.catalyst.expressions.Add",
- isTemporary=True))
- self.assertEquals(functions, functionsDefault)
-
- with self.function("func1", "some_db.func2"):
- spark.catalog.registerFunction("temp_func", lambda x: str(x))
- spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
- spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
- newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
- newFunctionsSomeDb = \
- dict((f.name, f) for f in spark.catalog.listFunctions("some_db"))
- self.assertTrue(set(functions).issubset(set(newFunctions)))
- self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
- self.assertTrue("temp_func" in newFunctions)
- self.assertTrue("func1" in newFunctions)
- self.assertTrue("func2" not in newFunctions)
- self.assertTrue("temp_func" in newFunctionsSomeDb)
- self.assertTrue("func1" not in newFunctionsSomeDb)
- self.assertTrue("func2" in newFunctionsSomeDb)
- self.assertRaisesRegexp(
- AnalysisException,
- "does_not_exist",
- lambda: spark.catalog.listFunctions("does_not_exist"))
-
- def test_list_columns(self):
- from pyspark.sql.catalog import Column
- spark = self.spark
- with self.database("some_db"):
- spark.sql("CREATE DATABASE some_db")
- with self.table("tab1", "some_db.tab2"):
- spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
- spark.sql(
- "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet")
- columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
- columnsDefault = \
- sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
- self.assertEquals(columns, columnsDefault)
- self.assertEquals(len(columns), 2)
- self.assertEquals(columns[0], Column(
- name="age",
- description=None,
- dataType="int",
- nullable=True,
- isPartition=False,
- isBucket=False))
- self.assertEquals(columns[1], Column(
- name="name",
- description=None,
- dataType="string",
- nullable=True,
- isPartition=False,
- isBucket=False))
- columns2 = \
- sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
- self.assertEquals(len(columns2), 2)
- self.assertEquals(columns2[0], Column(
- name="nickname",
- description=None,
- dataType="string",
- nullable=True,
- isPartition=False,
- isBucket=False))
- self.assertEquals(columns2[1], Column(
- name="tolerance",
- description=None,
- dataType="float",
- nullable=True,
- isPartition=False,
- isBucket=False))
- self.assertRaisesRegexp(
- AnalysisException,
- "tab2",
- lambda: spark.catalog.listColumns("tab2"))
- self.assertRaisesRegexp(
- AnalysisException,
- "does_not_exist",
- lambda: spark.catalog.listColumns("does_not_exist"))
-
- def test_cache(self):
- spark = self.spark
- with self.tempView("tab1", "tab2"):
- spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1")
- spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2")
- self.assertFalse(spark.catalog.isCached("tab1"))
- self.assertFalse(spark.catalog.isCached("tab2"))
- spark.catalog.cacheTable("tab1")
- self.assertTrue(spark.catalog.isCached("tab1"))
- self.assertFalse(spark.catalog.isCached("tab2"))
- spark.catalog.cacheTable("tab2")
- spark.catalog.uncacheTable("tab1")
- self.assertFalse(spark.catalog.isCached("tab1"))
- self.assertTrue(spark.catalog.isCached("tab2"))
- spark.catalog.clearCache()
- self.assertFalse(spark.catalog.isCached("tab1"))
- self.assertFalse(spark.catalog.isCached("tab2"))
- self.assertRaisesRegexp(
- AnalysisException,
- "does_not_exist",
- lambda: spark.catalog.isCached("does_not_exist"))
- self.assertRaisesRegexp(
- AnalysisException,
- "does_not_exist",
- lambda: spark.catalog.cacheTable("does_not_exist"))
- self.assertRaisesRegexp(
- AnalysisException,
- "does_not_exist",
- lambda: spark.catalog.uncacheTable("does_not_exist"))
-
- def test_read_text_file_list(self):
- df = self.spark.read.text(['python/test_support/sql/text-test.txt',
- 'python/test_support/sql/text-test.txt'])
- count = df.count()
- self.assertEquals(count, 4)
-
- def test_BinaryType_serialization(self):
- # Pyrolite version <= 4.9 could not serialize BinaryType with Python3 SPARK-17808
-        # The empty bytearray is a test for SPARK-21534.
- schema = StructType([StructField('mybytes', BinaryType())])
- data = [[bytearray(b'here is my data')],
- [bytearray(b'and here is some more')],
- [bytearray(b'')]]
- df = self.spark.createDataFrame(data, schema=schema)
- df.collect()
-
- # test for SPARK-16542
- def test_array_types(self):
-        # This test needs to make sure that the Scala type selected is at least
-        # as large as the Python type. This is necessary because Python's
-        # array types depend on the C implementation on the machine. Therefore there
-        # is no machine-independent correspondence between Python's array types
-        # and Scala types.
- # See: https://docs.python.org/2/library/array.html
-
- def assertCollectSuccess(typecode, value):
- row = Row(myarray=array.array(typecode, [value]))
- df = self.spark.createDataFrame([row])
- self.assertEqual(df.first()["myarray"][0], value)
-
- # supported string types
- #
- # String types in python's array are "u" for Py_UNICODE and "c" for char.
- # "u" will be removed in python 4, and "c" is not supported in python 3.
- supported_string_types = []
- if sys.version_info[0] < 4:
- supported_string_types += ['u']
- # test unicode
- assertCollectSuccess('u', u'a')
- if sys.version_info[0] < 3:
- supported_string_types += ['c']
- # test string
- assertCollectSuccess('c', 'a')
-
- # supported float and double
- #
- # Test max, min, and precision for float and double, assuming IEEE 754
- # floating-point format.
- supported_fractional_types = ['f', 'd']
- assertCollectSuccess('f', ctypes.c_float(1e+38).value)
- assertCollectSuccess('f', ctypes.c_float(1e-38).value)
- assertCollectSuccess('f', ctypes.c_float(1.123456).value)
- assertCollectSuccess('d', sys.float_info.max)
- assertCollectSuccess('d', sys.float_info.min)
- assertCollectSuccess('d', sys.float_info.epsilon)
-
- # supported signed int types
- #
-        # The size of C types changes with the implementation, so we need to make
-        # sure that there is no overflow error on the platform running this test.
- supported_signed_int_types = list(
- set(_array_signed_int_typecode_ctype_mappings.keys())
- .intersection(set(_array_type_mappings.keys())))
- for t in supported_signed_int_types:
- ctype = _array_signed_int_typecode_ctype_mappings[t]
- max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1)
- assertCollectSuccess(t, max_val - 1)
- assertCollectSuccess(t, -max_val)
-
- # supported unsigned int types
- #
- # JVM does not have unsigned types. We need to be very careful to make
- # sure that there is no overflow error.
- supported_unsigned_int_types = list(
- set(_array_unsigned_int_typecode_ctype_mappings.keys())
- .intersection(set(_array_type_mappings.keys())))
- for t in supported_unsigned_int_types:
- ctype = _array_unsigned_int_typecode_ctype_mappings[t]
- assertCollectSuccess(t, 2 ** (ctypes.sizeof(ctype) * 8) - 1)
-
- # all supported types
- #
- # Make sure the types tested above:
- # 1. are all supported types
- # 2. cover all supported types
- supported_types = (supported_string_types +
- supported_fractional_types +
- supported_signed_int_types +
- supported_unsigned_int_types)
- self.assertEqual(set(supported_types), set(_array_type_mappings.keys()))
-
- # all unsupported types
- #
-        # The keys of _array_type_mappings form a complete list of all supported types,
-        # and types not in _array_type_mappings are considered unsupported.
-        # `array.typecodes` is not available in python 2.
- if sys.version_info[0] < 3:
- all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd'])
- else:
- all_types = set(array.typecodes)
- unsupported_types = all_types - set(supported_types)
- # test unsupported types
- for t in unsupported_types:
- with self.assertRaises(TypeError):
- a = array.array(t)
- self.spark.createDataFrame([Row(myarray=a)]).collect()
-
- def test_bucketed_write(self):
- data = [
- (1, "foo", 3.0), (2, "foo", 5.0),
- (3, "bar", -1.0), (4, "bar", 6.0),
- ]
- df = self.spark.createDataFrame(data, ["x", "y", "z"])
-
- def count_bucketed_cols(names, table="pyspark_bucket"):
-            """Given a sequence of column names and a table name,
-            query the catalog and return the number of columns which are
-            used for bucketing.
- """
- cols = self.spark.catalog.listColumns(table)
- num = len([c for c in cols if c.name in names and c.isBucket])
- return num
-
- with self.table("pyspark_bucket"):
- # Test write with one bucketing column
- df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
- self.assertEqual(count_bucketed_cols(["x"]), 1)
- self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
- # Test write two bucketing columns
- df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
- self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
- self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
- # Test write with bucket and sort
- df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
- self.assertEqual(count_bucketed_cols(["x"]), 1)
- self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
- # Test write with a list of columns
- df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
- self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
- self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
- # Test write with bucket and sort with a list of columns
- (df.write.bucketBy(2, "x")
- .sortBy(["y", "z"])
- .mode("overwrite").saveAsTable("pyspark_bucket"))
- self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
- # Test write with bucket and sort with multiple columns
- (df.write.bucketBy(2, "x")
- .sortBy("y", "z")
- .mode("overwrite").saveAsTable("pyspark_bucket"))
- self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
-
- def _to_pandas(self):
- from datetime import datetime, date
- schema = StructType().add("a", IntegerType()).add("b", StringType())\
- .add("c", BooleanType()).add("d", FloatType())\
- .add("dt", DateType()).add("ts", TimestampType())
- data = [
- (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
- (2, "foo", True, 5.0, None, None),
- (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
- (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
- ]
- df = self.spark.createDataFrame(data, schema)
- return df.toPandas()
-
- @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
- def test_to_pandas(self):
- import numpy as np
- pdf = self._to_pandas()
- types = pdf.dtypes
- self.assertEquals(types[0], np.int32)
- self.assertEquals(types[1], np.object)
- self.assertEquals(types[2], np.bool)
- self.assertEquals(types[3], np.float32)
- self.assertEquals(types[4], np.object) # datetime.date
- self.assertEquals(types[5], 'datetime64[ns]')
-
- @unittest.skipIf(_have_pandas, "Required Pandas was found.")
- def test_to_pandas_required_pandas_not_found(self):
- with QuietTest(self.sc):
- with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
- self._to_pandas()
-
- @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
- def test_to_pandas_avoid_astype(self):
- import numpy as np
- schema = StructType().add("a", IntegerType()).add("b", StringType())\
- .add("c", IntegerType())
- data = [(1, "foo", 16777220), (None, "bar", None)]
- df = self.spark.createDataFrame(data, schema)
- types = df.toPandas().dtypes
- self.assertEquals(types[0], np.float64) # doesn't convert to np.int32 due to NaN value.
- self.assertEquals(types[1], np.object)
- self.assertEquals(types[2], np.float64)
-
- def test_create_dataframe_from_array_of_long(self):
- import array
- data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))]
- df = self.spark.createDataFrame(data)
- self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))
-
- @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
- def test_create_dataframe_from_pandas_with_timestamp(self):
- import pandas as pd
- from datetime import datetime
- pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
- "d": [pd.Timestamp.now().date()]})[["d", "ts"]]
- # test types are inferred correctly without specifying schema
- df = self.spark.createDataFrame(pdf)
- self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
- self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
- # test with schema will accept pdf as input
- df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
- self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
- self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
-
- @unittest.skipIf(_have_pandas, "Required Pandas was found.")
- def test_create_dataframe_required_pandas_not_found(self):
- with QuietTest(self.sc):
- with self.assertRaisesRegexp(
- ImportError,
- "(Pandas >= .* must be installed|No module named '?pandas'?)"):
- import pandas as pd
- from datetime import datetime
- pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
- "d": [pd.Timestamp.now().date()]})
- self.spark.createDataFrame(pdf)
-
- # Regression test for SPARK-23360
- @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
- def test_create_dateframe_from_pandas_with_dst(self):
- import pandas as pd
- from datetime import datetime
-
- pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})
-
- df = self.spark.createDataFrame(pdf)
- self.assertPandasEqual(pdf, df.toPandas())
-
- orig_env_tz = os.environ.get('TZ', None)
- try:
- tz = 'America/Los_Angeles'
- os.environ['TZ'] = tz
- time.tzset()
- with self.sql_conf({'spark.sql.session.timeZone': tz}):
- df = self.spark.createDataFrame(pdf)
- self.assertPandasEqual(pdf, df.toPandas())
- finally:
- del os.environ['TZ']
- if orig_env_tz is not None:
- os.environ['TZ'] = orig_env_tz
- time.tzset()
-
- def test_sort_with_nulls_order(self):
- from pyspark.sql import functions
-
- df = self.spark.createDataFrame(
- [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"])
- self.assertEquals(
- df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(),
- [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')])
- self.assertEquals(
- df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(),
- [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)])
- self.assertEquals(
- df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(),
- [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')])
- self.assertEquals(
- df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
- [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])
-
- def test_json_sampling_ratio(self):
- rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
- .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x))
- schema = self.spark.read.option('inferSchema', True) \
- .option('samplingRatio', 0.5) \
- .json(rdd).schema
- self.assertEquals(schema, StructType([StructField("a", LongType(), True)]))
-
- def test_csv_sampling_ratio(self):
- rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
- .map(lambda x: '0.1' if x == 1 else str(x))
- schema = self.spark.read.option('inferSchema', True)\
- .csv(rdd, samplingRatio=0.5).schema
- self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)]))
-
- def test_checking_csv_header(self):
- path = tempfile.mkdtemp()
- shutil.rmtree(path)
- try:
- self.spark.createDataFrame([[1, 1000], [2000, 2]])\
- .toDF('f1', 'f2').write.option("header", "true").csv(path)
- schema = StructType([
- StructField('f2', IntegerType(), nullable=True),
- StructField('f1', IntegerType(), nullable=True)])
- df = self.spark.read.option('header', 'true').schema(schema)\
- .csv(path, enforceSchema=False)
- self.assertRaisesRegexp(
- Exception,
- "CSV header does not conform to the schema",
- lambda: df.collect())
- finally:
- shutil.rmtree(path)
-
- def test_ignore_column_of_all_nulls(self):
- path = tempfile.mkdtemp()
- shutil.rmtree(path)
- try:
- df = self.spark.createDataFrame([["""{"a":null, "b":1, "c":3.0}"""],
- ["""{"a":null, "b":null, "c":"string"}"""],
- ["""{"a":null, "b":null, "c":null}"""]])
- df.write.text(path)
- schema = StructType([
- StructField('b', LongType(), nullable=True),
- StructField('c', StringType(), nullable=True)])
- readback = self.spark.read.json(path, dropFieldIfAllNull=True)
- self.assertEquals(readback.schema, schema)
- finally:
- shutil.rmtree(path)
-
- # SPARK-24721
- @unittest.skipIf(not _test_compiled, _test_not_compiled_message)
- def test_datasource_with_udf(self):
- from pyspark.sql.functions import udf, lit, col
-
- path = tempfile.mkdtemp()
- shutil.rmtree(path)
-
- try:
- self.spark.range(1).write.mode("overwrite").format('csv').save(path)
- filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
- datasource_df = self.spark.read \
- .format("org.apache.spark.sql.sources.SimpleScanSource") \
- .option('from', 0).option('to', 1).load().toDF('i')
- datasource_v2_df = self.spark.read \
- .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
- .load().toDF('i', 'j')
-
- c1 = udf(lambda x: x + 1, 'int')(lit(1))
- c2 = udf(lambda x: x + 1, 'int')(col('i'))
-
- f1 = udf(lambda x: False, 'boolean')(lit(1))
- f2 = udf(lambda x: False, 'boolean')(col('i'))
-
- for df in [filesource_df, datasource_df, datasource_v2_df]:
- result = df.withColumn('c', c1)
- expected = df.withColumn('c', lit(2))
- self.assertEquals(expected.collect(), result.collect())
-
- for df in [filesource_df, datasource_df, datasource_v2_df]:
- result = df.withColumn('c', c2)
- expected = df.withColumn('c', col('i') + 1)
- self.assertEquals(expected.collect(), result.collect())
-
- for df in [filesource_df, datasource_df, datasource_v2_df]:
- for f in [f1, f2]:
- result = df.filter(f)
- self.assertEquals(0, result.count())
- finally:
- shutil.rmtree(path)
-
- def test_repr_behaviors(self):
- import re
- pattern = re.compile(r'^ *\|', re.MULTILINE)
- df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
-
- # test when eager evaluation is enabled and _repr_html_ will not be called
- with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
- expected1 = """+-----+-----+
- || key|value|
- |+-----+-----+
- || 1| 1|
- ||22222|22222|
- |+-----+-----+
- |"""
- self.assertEquals(re.sub(pattern, '', expected1), df.__repr__())
- with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
- expected2 = """+---+-----+
- ||key|value|
- |+---+-----+
- || 1| 1|
- ||222| 222|
- |+---+-----+
- |"""
- self.assertEquals(re.sub(pattern, '', expected2), df.__repr__())
- with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
- expected3 = """+---+-----+
- ||key|value|
- |+---+-----+
- || 1| 1|
- |+---+-----+
- |only showing top 1 row
- |"""
- self.assertEquals(re.sub(pattern, '', expected3), df.__repr__())
-
- # test when eager evaluation is enabled and _repr_html_ will be called
- with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
- expected1 = """
+ |only showing top 1 row
+ |"""
+ self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
+
+ # test when eager evaluation is disabled and _repr_html_ will be called
+ with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
+ expected = "DataFrame[key: bigint, value: string]"
+ self.assertEquals(None, df._repr_html_())
+ self.assertEquals(expected, df.__repr__())
+ with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+ self.assertEquals(None, df._repr_html_())
+ self.assertEquals(expected, df.__repr__())
+ with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+ self.assertEquals(None, df._repr_html_())
+ self.assertEquals(expected, df.__repr__())
+
+
+class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
+    # These tests are separate because they use 'spark.sql.queryExecutionListeners', which is
+    # static and immutable. It can't be set or unset, for example, via `spark.conf`.
+
+ @classmethod
+ def setUpClass(cls):
+ import glob
+ from pyspark.find_spark_home import _find_spark_home
+
+ SPARK_HOME = _find_spark_home()
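+        # Probe the compiled Scala test classes: the listener tests only make sense
+        # when TestQueryExecutionListener has been built, otherwise setUp skips them.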
+ filename_pattern = (
+ "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+ "TestQueryExecutionListener.class")
+ cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern)))
+
+ if cls.has_listener:
+ # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
+ cls.spark = SparkSession.builder \
+ .master("local[4]") \
+ .appName(cls.__name__) \
+ .config(
+ "spark.sql.queryExecutionListeners",
+ "org.apache.spark.sql.TestQueryExecutionListener") \
+ .getOrCreate()
+
+ def setUp(self):
+ if not self.has_listener:
+ raise self.skipTest(
+ "'org.apache.spark.sql.TestQueryExecutionListener' is not "
+ "available. Will skip the related tests.")
+
+ @classmethod
+ def tearDownClass(cls):
+ if hasattr(cls, "spark"):
+ cls.spark.stop()
+
+ def tearDown(self):
+ self.spark._jvm.OnSuccessCall.clear()
+
+ def test_query_execution_listener_on_collect(self):
+ self.assertFalse(
+ self.spark._jvm.OnSuccessCall.isCalled(),
+ "The callback from the query execution listener should not be called before 'collect'")
+ self.spark.sql("SELECT * FROM range(1)").collect()
+ self.assertTrue(
+ self.spark._jvm.OnSuccessCall.isCalled(),
+ "The callback from the query execution listener should be called after 'collect'")
+
+ @unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message)
+ def test_query_execution_listener_on_collect_with_arrow(self):
+ with self.sql_conf({"spark.sql.execution.arrow.enabled": True}):
+ self.assertFalse(
+ self.spark._jvm.OnSuccessCall.isCalled(),
+ "The callback from the query execution listener should not be "
+ "called before 'toPandas'")
+ self.spark.sql("SELECT * FROM range(1)").toPandas()
+ self.assertTrue(
+ self.spark._jvm.OnSuccessCall.isCalled(),
+ "The callback from the query execution listener should be called after 'toPandas'")
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_dataframe import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py
new file mode 100644
index 0000000000000..5579620bc2be1
--- /dev/null
+++ b/python/pyspark/sql/tests/test_datasources.py
@@ -0,0 +1,171 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import shutil
+import tempfile
+
+from pyspark.sql import Row
+from pyspark.sql.types import *
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class DataSourcesTests(ReusedSQLTestCase):
+
+ def test_linesep_text(self):
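+        # With lineSep="," the text reader splits the file on commas instead of
+        # newlines, so each comma-separated field (even with embedded newlines)
+        # becomes its own row.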
+ df = self.spark.read.text("python/test_support/sql/ages_newlines.csv", lineSep=",")
+ expected = [Row(value=u'Joe'), Row(value=u'20'), Row(value=u'"Hi'),
+ Row(value=u'\nI am Jeo"\nTom'), Row(value=u'30'),
+ Row(value=u'"My name is Tom"\nHyukjin'), Row(value=u'25'),
+ Row(value=u'"I am Hyukjin\n\nI love Spark!"\n')]
+ self.assertEqual(df.collect(), expected)
+
+ tpath = tempfile.mkdtemp()
+ shutil.rmtree(tpath)
+ try:
+ df.write.text(tpath, lineSep="!")
+ expected = [Row(value=u'Joe!20!"Hi!'), Row(value=u'I am Jeo"'),
+ Row(value=u'Tom!30!"My name is Tom"'),
+ Row(value=u'Hyukjin!25!"I am Hyukjin'),
+ Row(value=u''), Row(value=u'I love Spark!"'),
+ Row(value=u'!')]
+ readback = self.spark.read.text(tpath)
+ self.assertEqual(readback.collect(), expected)
+ finally:
+ shutil.rmtree(tpath)
+
+ def test_multiline_json(self):
+ people1 = self.spark.read.json("python/test_support/sql/people.json")
+ people_array = self.spark.read.json("python/test_support/sql/people_array.json",
+ multiLine=True)
+ self.assertEqual(people1.collect(), people_array.collect())
+
+ def test_encoding_json(self):
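+        # The input is a multi-line JSON array stored as UTF-16LE, so both
+        # multiLine and an explicit encoding are needed to read it correctly.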
+ people_array = self.spark.read\
+ .json("python/test_support/sql/people_array_utf16le.json",
+ multiLine=True, encoding="UTF-16LE")
+ expected = [Row(age=30, name=u'Andy'), Row(age=19, name=u'Justin')]
+ self.assertEqual(people_array.collect(), expected)
+
+ def test_linesep_json(self):
+ df = self.spark.read.json("python/test_support/sql/people.json", lineSep=",")
+ expected = [Row(_corrupt_record=None, name=u'Michael'),
+ Row(_corrupt_record=u' "age":30}\n{"name":"Justin"', name=None),
+ Row(_corrupt_record=u' "age":19}\n', name=None)]
+ self.assertEqual(df.collect(), expected)
+
+ tpath = tempfile.mkdtemp()
+ shutil.rmtree(tpath)
+ try:
+ df = self.spark.read.json("python/test_support/sql/people.json")
+ df.write.json(tpath, lineSep="!!")
+ readback = self.spark.read.json(tpath, lineSep="!!")
+ self.assertEqual(readback.collect(), df.collect())
+ finally:
+ shutil.rmtree(tpath)
+
+ def test_multiline_csv(self):
+ ages_newlines = self.spark.read.csv(
+ "python/test_support/sql/ages_newlines.csv", multiLine=True)
+ expected = [Row(_c0=u'Joe', _c1=u'20', _c2=u'Hi,\nI am Jeo'),
+ Row(_c0=u'Tom', _c1=u'30', _c2=u'My name is Tom'),
+ Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
+ self.assertEqual(ages_newlines.collect(), expected)
+
+ def test_ignorewhitespace_csv(self):
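+        # Disabling both ignore*WhiteSpace options must preserve the padding around
+        # the values; reading the written CSV back as raw text makes that visible.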
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ self.spark.createDataFrame([[" a", "b ", " c "]]).write.csv(
+ tmpPath,
+ ignoreLeadingWhiteSpace=False,
+ ignoreTrailingWhiteSpace=False)
+
+ expected = [Row(value=u' a,b , c ')]
+ readback = self.spark.read.text(tmpPath)
+ self.assertEqual(readback.collect(), expected)
+ shutil.rmtree(tmpPath)
+
+ def test_read_multiple_orc_file(self):
+ df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0",
+ "python/test_support/sql/orc_partitioned/b=1/c=1"])
+ self.assertEqual(2, df.count())
+
+ def test_read_text_file_list(self):
+ df = self.spark.read.text(['python/test_support/sql/text-test.txt',
+ 'python/test_support/sql/text-test.txt'])
+ count = df.count()
+ self.assertEquals(count, 4)
+
+ def test_json_sampling_ratio(self):
+ rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
+ .map(lambda x: '{"a":0.1}' if x == 1 else '{"a":%s}' % str(x))
+ schema = self.spark.read.option('inferSchema', True) \
+ .option('samplingRatio', 0.5) \
+ .json(rdd).schema
+ self.assertEquals(schema, StructType([StructField("a", LongType(), True)]))
+
+ def test_csv_sampling_ratio(self):
+ rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
+ .map(lambda x: '0.1' if x == 1 else str(x))
+ schema = self.spark.read.option('inferSchema', True)\
+ .csv(rdd, samplingRatio=0.5).schema
+ self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)]))
+
+ def test_checking_csv_header(self):
+ path = tempfile.mkdtemp()
+ shutil.rmtree(path)
+ try:
+ self.spark.createDataFrame([[1, 1000], [2000, 2]])\
+ .toDF('f1', 'f2').write.option("header", "true").csv(path)
+ schema = StructType([
+ StructField('f2', IntegerType(), nullable=True),
+ StructField('f1', IntegerType(), nullable=True)])
+ df = self.spark.read.option('header', 'true').schema(schema)\
+ .csv(path, enforceSchema=False)
+ self.assertRaisesRegexp(
+ Exception,
+ "CSV header does not conform to the schema",
+ lambda: df.collect())
+ finally:
+ shutil.rmtree(path)
+
+ def test_ignore_column_of_all_nulls(self):
+ path = tempfile.mkdtemp()
+ shutil.rmtree(path)
+ try:
+ df = self.spark.createDataFrame([["""{"a":null, "b":1, "c":3.0}"""],
+ ["""{"a":null, "b":null, "c":"string"}"""],
+ ["""{"a":null, "b":null, "c":null}"""]])
+ df.write.text(path)
+ schema = StructType([
+ StructField('b', LongType(), nullable=True),
+ StructField('c', StringType(), nullable=True)])
+ readback = self.spark.read.json(path, dropFieldIfAllNull=True)
+ self.assertEquals(readback.schema, schema)
+ finally:
+ shutil.rmtree(path)
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.test_datasources import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
new file mode 100644
index 0000000000000..fe6660272e323
--- /dev/null
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -0,0 +1,279 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import sys
+
+from pyspark.sql import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class FunctionsTests(ReusedSQLTestCase):
+
+ def test_explode(self):
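+        # explode flattens arrays/maps into one row per element; the *_outer
+        # variants keep rows whose array/map is empty or null by emitting nulls.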
+ from pyspark.sql.functions import explode, explode_outer, posexplode_outer
+ d = [
+ Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
+ Row(a=1, intlist=[], mapfield={}),
+ Row(a=1, intlist=None, mapfield=None),
+ ]
+ rdd = self.sc.parallelize(d)
+ data = self.spark.createDataFrame(rdd)
+
+ result = data.select(explode(data.intlist).alias("a")).select("a").collect()
+ self.assertEqual(result[0][0], 1)
+ self.assertEqual(result[1][0], 2)
+ self.assertEqual(result[2][0], 3)
+
+ result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
+ self.assertEqual(result[0][0], "a")
+ self.assertEqual(result[0][1], "b")
+
+ result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
+ self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])
+
+ result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
+ self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])
+
+ result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
+ self.assertEqual(result, [1, 2, 3, None, None])
+
+ result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
+ self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])
+
+ def test_basic_functions(self):
+ rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+ df = self.spark.read.json(rdd)
+ df.count()
+ df.collect()
+ df.schema
+
+ # cache and checkpoint
+ self.assertFalse(df.is_cached)
+ df.persist()
+ df.unpersist(True)
+ df.cache()
+ self.assertTrue(df.is_cached)
+ self.assertEqual(2, df.count())
+
+ with self.tempView("temp"):
+ df.createOrReplaceTempView("temp")
+ df = self.spark.sql("select foo from temp")
+ df.count()
+ df.collect()
+
+ def test_corr(self):
+ import math
+ df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
+ corr = df.stat.corr(u"a", "b")
+ self.assertTrue(abs(corr - 0.95734012) < 1e-6)
+
+ def test_sampleby(self):
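+        # sampleBy does stratified sampling on column 'b', keeping roughly 50% of
+        # each stratum; the fixed seed makes the resulting count reproducible.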
+ df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF()
+ sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
+ self.assertTrue(sampled.count() == 3)
+
+ def test_cov(self):
+ df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
+ cov = df.stat.cov(u"a", "b")
+ self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
+
+ def test_crosstab(self):
+ df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
+ ct = df.stat.crosstab(u"a", "b").collect()
+ ct = sorted(ct, key=lambda x: x[0])
+ for i, row in enumerate(ct):
+ self.assertEqual(row[0], str(i))
+ self.assertTrue(row[1], 1)
+ self.assertTrue(row[2], 1)
+
+ def test_math_functions(self):
+ df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
+ from pyspark.sql import functions
+ import math
+
+ def get_values(l):
+ return [j[0] for j in l]
+
+ def assert_close(a, b):
+ c = get_values(b)
+ diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
+ return sum(diff) == len(a)
+ assert_close([math.cos(i) for i in range(10)],
+ df.select(functions.cos(df.a)).collect())
+ assert_close([math.cos(i) for i in range(10)],
+ df.select(functions.cos("a")).collect())
+ assert_close([math.sin(i) for i in range(10)],
+ df.select(functions.sin(df.a)).collect())
+ assert_close([math.sin(i) for i in range(10)],
+ df.select(functions.sin(df['a'])).collect())
+ assert_close([math.pow(i, 2 * i) for i in range(10)],
+ df.select(functions.pow(df.a, df.b)).collect())
+ assert_close([math.pow(i, 2) for i in range(10)],
+ df.select(functions.pow(df.a, 2)).collect())
+ assert_close([math.pow(i, 2) for i in range(10)],
+ df.select(functions.pow(df.a, 2.0)).collect())
+ assert_close([math.hypot(i, 2 * i) for i in range(10)],
+ df.select(functions.hypot(df.a, df.b)).collect())
+
+ def test_rand_functions(self):
+ df = self.df
+ from pyspark.sql import functions
+ rnd = df.select('key', functions.rand()).collect()
+ for row in rnd:
+ assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
+ rndn = df.select('key', functions.randn(5)).collect()
+ for row in rndn:
+ assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
+
+ # If the specified seed is 0, we should use it.
+ # https://issues.apache.org/jira/browse/SPARK-9691
+ rnd1 = df.select('key', functions.rand(0)).collect()
+ rnd2 = df.select('key', functions.rand(0)).collect()
+ self.assertEqual(sorted(rnd1), sorted(rnd2))
+
+ rndn1 = df.select('key', functions.randn(0)).collect()
+ rndn2 = df.select('key', functions.randn(0)).collect()
+ self.assertEqual(sorted(rndn1), sorted(rndn2))
+
+ def test_string_functions(self):
+ from pyspark.sql.functions import col, lit
+ df = self.spark.createDataFrame([['nick']], schema=['name'])
+ self.assertRaisesRegexp(
+ TypeError,
+ "must be the same type",
+ lambda: df.select(col('name').substr(0, lit(1))))
+ if sys.version_info.major == 2:
+ self.assertRaises(
+ TypeError,
+ lambda: df.select(col('name').substr(long(0), long(1))))
+
+ def test_array_contains_function(self):
+ from pyspark.sql.functions import array_contains
+
+ df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data'])
+ actual = df.select(array_contains(df.data, "1").alias('b')).collect()
+ self.assertEqual([Row(b=True), Row(b=False)], actual)
+
+ def test_between_function(self):
+ df = self.sc.parallelize([
+ Row(a=1, b=2, c=3),
+ Row(a=2, b=1, c=3),
+ Row(a=4, b=1, c=4)]).toDF()
+ self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
+ df.filter(df.a.between(df.b, df.c)).collect())
+
+ def test_dayofweek(self):
+ from pyspark.sql.functions import dayofweek
+ dt = datetime.datetime(2017, 11, 6)
+ df = self.spark.createDataFrame([Row(date=dt)])
+ row = df.select(dayofweek(df.date)).first()
+ self.assertEqual(row[0], 2)
+
+ def test_expr(self):
+ from pyspark.sql import functions
+ row = Row(a="length string", b=75)
+ df = self.spark.createDataFrame([row])
+ result = df.select(functions.expr("length(a)")).collect()[0].asDict()
+ self.assertEqual(13, result["length(a)"])
+
+ # add test for SPARK-10577 (test broadcast join hint)
+ def test_functions_broadcast(self):
+ from pyspark.sql.functions import broadcast
+
+ df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
+ df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
+
+ # equijoin - should be converted into broadcast join
+ plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
+ self.assertEqual(1, plan1.toString().count("BroadcastHashJoin"))
+
+ # no join key -- should not be a broadcast join
+ plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan()
+ self.assertEqual(0, plan2.toString().count("BroadcastHashJoin"))
+
+ # planner should not crash without a join
+ broadcast(df1)._jdf.queryExecution().executedPlan()
+
+ def test_first_last_ignorenulls(self):
+ from pyspark.sql import functions
+ df = self.spark.range(0, 100)
+ df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
+ df3 = df2.select(functions.first(df2.id, False).alias('a'),
+ functions.first(df2.id, True).alias('b'),
+ functions.last(df2.id, False).alias('c'),
+ functions.last(df2.id, True).alias('d'))
+ self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
+
+ def test_approxQuantile(self):
+ df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
+ for f in ["a", u"a"]:
+ aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
+ self.assertTrue(isinstance(aq, list))
+ self.assertEqual(len(aq), 3)
+ self.assertTrue(all(isinstance(q, float) for q in aq))
+ aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
+ self.assertTrue(isinstance(aqs, list))
+ self.assertEqual(len(aqs), 2)
+ self.assertTrue(isinstance(aqs[0], list))
+ self.assertEqual(len(aqs[0]), 3)
+ self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
+ self.assertTrue(isinstance(aqs[1], list))
+ self.assertEqual(len(aqs[1]), 3)
+ self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
+ aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
+ self.assertTrue(isinstance(aqt, list))
+ self.assertEqual(len(aqt), 2)
+ self.assertTrue(isinstance(aqt[0], list))
+ self.assertEqual(len(aqt[0]), 3)
+ self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
+ self.assertTrue(isinstance(aqt[1], list))
+ self.assertEqual(len(aqt[1]), 3)
+ self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
+ self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
+ self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
+ self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
+
+ def test_sort_with_nulls_order(self):
+ from pyspark.sql import functions
+
+ df = self.spark.createDataFrame(
+ [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"])
+ self.assertEquals(
+ df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(),
+ [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')])
+ self.assertEquals(
+ df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(),
+ [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)])
+ self.assertEquals(
+ df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(),
+ [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')])
+ self.assertEquals(
+ df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
+ [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.test_functions import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
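Aside for reviewers: the seed behaviour asserted in test_rand_functions above (SPARK-9691: an explicit seed of 0 must be honoured rather than treated as "no seed") can be reproduced with a small sketch like the one below. It is illustrative only, not part of the patch, and assumes a local SparkSession bound to `spark`.

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.master("local[1]").getOrCreate()  # assumed local session
    df = spark.range(5)

    # Two plans built with the same explicit seed (including 0) produce identical values.
    a = df.select(F.rand(0).alias('r')).collect()
    b = df.select(F.rand(0).alias('r')).collect()
    assert sorted(a) == sorted(b)

    # Without a seed, each rand() expression picks its own seed, so results generally differ.
    c = df.select(F.rand().alias('r')).collect()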
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
new file mode 100644
index 0000000000000..6de1b8ea0b3ce
--- /dev/null
+++ b/python/pyspark/sql/tests/test_group.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class GroupTests(ReusedSQLTestCase):
+
+ def test_aggregator(self):
+ df = self.df
+ g = df.groupBy()
+ self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+ self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+
+ from pyspark.sql import functions
+ self.assertEqual((0, u'99'),
+ tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
+ self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0])
+ self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.test_group import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py
new file mode 100644
index 0000000000000..c4b5478a7e893
--- /dev/null
+++ b/python/pyspark/sql/tests/test_pandas_udf.py
@@ -0,0 +1,217 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.types import *
+from pyspark.sql.utils import ParseException
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
+ pandas_requirement_message, pyarrow_requirement_message
+from pyspark.testing.utils import QuietTest
+
+
+@unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message)
+class PandasUDFTests(ReusedSQLTestCase):
+
+ def test_pandas_udf_basic(self):
+ from pyspark.rdd import PythonEvalType
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ udf = pandas_udf(lambda x: x, DoubleType())
+ self.assertEqual(udf.returnType, DoubleType())
+ self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+
+ udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR)
+ self.assertEqual(udf.returnType, DoubleType())
+ self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+
+ udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR)
+ self.assertEqual(udf.returnType, DoubleType())
+ self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+
+ udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]),
+ PandasUDFType.GROUPED_MAP)
+ self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
+ self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+
+ udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
+ self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
+ self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+
+ udf = pandas_udf(lambda x: x, 'v double',
+ functionType=PandasUDFType.GROUPED_MAP)
+ self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
+ self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+
+ udf = pandas_udf(lambda x: x, returnType='v double',
+ functionType=PandasUDFType.GROUPED_MAP)
+ self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
+ self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+
+ def test_pandas_udf_decorator(self):
+ from pyspark.rdd import PythonEvalType
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+ from pyspark.sql.types import StructType, StructField, DoubleType
+
+ @pandas_udf(DoubleType())
+ def foo(x):
+ return x
+ self.assertEqual(foo.returnType, DoubleType())
+ self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+
+ @pandas_udf(returnType=DoubleType())
+ def foo(x):
+ return x
+ self.assertEqual(foo.returnType, DoubleType())
+ self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+
+ schema = StructType([StructField("v", DoubleType())])
+
+ @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
+ def foo(x):
+ return x
+ self.assertEqual(foo.returnType, schema)
+ self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+
+ @pandas_udf('v double', PandasUDFType.GROUPED_MAP)
+ def foo(x):
+ return x
+ self.assertEqual(foo.returnType, schema)
+ self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+
+ @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
+ def foo(x):
+ return x
+ self.assertEqual(foo.returnType, schema)
+ self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+
+ @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR)
+ def foo(x):
+ return x
+ self.assertEqual(foo.returnType, DoubleType())
+ self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+
+ @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
+ def foo(x):
+ return x
+ self.assertEqual(foo.returnType, schema)
+ self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+
+ def test_udf_wrong_arg(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ with QuietTest(self.sc):
+ with self.assertRaises(ParseException):
+ @pandas_udf('blah')
+ def foo(x):
+ return x
+ with self.assertRaisesRegexp(ValueError, 'Invalid returnType.*None'):
+ @pandas_udf(functionType=PandasUDFType.SCALAR)
+ def foo(x):
+ return x
+ with self.assertRaisesRegexp(ValueError, 'Invalid functionType'):
+ @pandas_udf('double', 100)
+ def foo(x):
+ return x
+
+ with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'):
+ pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR)
+ with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'):
+ @pandas_udf(LongType(), PandasUDFType.SCALAR)
+ def zero_with_type():
+ return 1
+
+ with self.assertRaisesRegexp(TypeError, 'Invalid returnType'):
+ @pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
+ def foo(df):
+ return df
+ with self.assertRaisesRegexp(TypeError, 'Invalid returnType'):
+ @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP)
+ def foo(df):
+ return df
+ with self.assertRaisesRegexp(ValueError, 'Invalid function'):
+ @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
+ def foo(k, v, w):
+ return k
+
+ def test_stopiteration_in_udf(self):
+ from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+ from py4j.protocol import Py4JJavaError
+
+ def foo(x):
+ raise StopIteration()
+
+ def foofoo(x, y):
+ raise StopIteration()
+
+ exc_message = "Caught StopIteration thrown from user's code; failing the task"
+ df = self.spark.range(0, 100)
+
+ # plain udf (test for SPARK-23754)
+ self.assertRaisesRegexp(
+ Py4JJavaError,
+ exc_message,
+ df.withColumn('v', udf(foo)('id')).collect
+ )
+
+ # pandas scalar udf
+ self.assertRaisesRegexp(
+ Py4JJavaError,
+ exc_message,
+ df.withColumn(
+ 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
+ ).collect
+ )
+
+ # pandas grouped map
+ self.assertRaisesRegexp(
+ Py4JJavaError,
+ exc_message,
+ df.groupBy('id').apply(
+ pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
+ ).collect
+ )
+
+ self.assertRaisesRegexp(
+ Py4JJavaError,
+ exc_message,
+ df.groupBy('id').apply(
+ pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP)
+ ).collect
+ )
+
+ # pandas grouped agg
+ self.assertRaisesRegexp(
+ Py4JJavaError,
+ exc_message,
+ df.groupBy('id').agg(
+ pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id')
+ ).collect
+ )
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_pandas_udf import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
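To make the declaration forms exercised by PandasUDFTests easier to follow, here is a minimal sketch of a scalar and a grouped map pandas UDF. It is illustrative only (not part of the patch) and assumes a local SparkSession named `spark` with pandas and pyarrow installed.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    from pyspark.sql.types import DoubleType

    spark = SparkSession.builder.master("local[1]").getOrCreate()  # assumed local session

    # Scalar pandas UDF: receives a pandas.Series, returns a Series of the same length.
    @pandas_udf(DoubleType())
    def plus_one(v):
        return v + 1

    # Grouped map pandas UDF: receives one pandas.DataFrame per group and must return a
    # DataFrame matching the declared schema.
    @pandas_udf('id long, v double', PandasUDFType.GROUPED_MAP)
    def identity(pdf):
        return pdf

    df = spark.range(3).selectExpr("id", "cast(id as double) as v")
    df.select(plus_one("v")).show()
    df.groupby("id").apply(identity).show()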
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
new file mode 100644
index 0000000000000..5383704434c85
--- /dev/null
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -0,0 +1,504 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.types import *
+from pyspark.sql.utils import AnalysisException
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
+ pandas_requirement_message, pyarrow_requirement_message
+from pyspark.testing.utils import QuietTest
+
+
+@unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message)
+class GroupedAggPandasUDFTests(ReusedSQLTestCase):
+
+ @property
+ def data(self):
+ from pyspark.sql.functions import array, explode, col, lit
+ return self.spark.range(10).toDF('id') \
+ .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
+ .withColumn("v", explode(col('vs'))) \
+ .drop('vs') \
+ .withColumn('w', lit(1.0))
+
+ @property
+ def python_plus_one(self):
+ from pyspark.sql.functions import udf
+
+ @udf('double')
+ def plus_one(v):
+ assert isinstance(v, (int, float))
+ return v + 1
+ return plus_one
+
+ @property
+ def pandas_scalar_plus_two(self):
+ import pandas as pd
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.SCALAR)
+ def plus_two(v):
+ assert isinstance(v, pd.Series)
+ return v + 2
+ return plus_two
+
+ @property
+ def pandas_agg_mean_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def avg(v):
+ return v.mean()
+ return avg
+
+ @property
+ def pandas_agg_sum_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def sum(v):
+ return v.sum()
+ return sum
+
+ @property
+ def pandas_agg_weighted_mean_udf(self):
+ import numpy as np
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def weighted_mean(v, w):
+ return np.average(v, weights=w)
+ return weighted_mean
+
+ def test_manual(self):
+ from pyspark.sql.functions import pandas_udf, array
+
+ df = self.data
+ sum_udf = self.pandas_agg_sum_udf
+ mean_udf = self.pandas_agg_mean_udf
+ mean_arr_udf = pandas_udf(
+ self.pandas_agg_mean_udf.func,
+ ArrayType(self.pandas_agg_mean_udf.returnType),
+ self.pandas_agg_mean_udf.evalType)
+
+ result1 = df.groupby('id').agg(
+ sum_udf(df.v),
+ mean_udf(df.v),
+ mean_arr_udf(array(df.v))).sort('id')
+ expected1 = self.spark.createDataFrame(
+ [[0, 245.0, 24.5, [24.5]],
+ [1, 255.0, 25.5, [25.5]],
+ [2, 265.0, 26.5, [26.5]],
+ [3, 275.0, 27.5, [27.5]],
+ [4, 285.0, 28.5, [28.5]],
+ [5, 295.0, 29.5, [29.5]],
+ [6, 305.0, 30.5, [30.5]],
+ [7, 315.0, 31.5, [31.5]],
+ [8, 325.0, 32.5, [32.5]],
+ [9, 335.0, 33.5, [33.5]]],
+ ['id', 'sum(v)', 'avg(v)', 'avg(array(v))'])
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_basic(self):
+ from pyspark.sql.functions import col, lit, mean
+
+ df = self.data
+ weighted_mean_udf = self.pandas_agg_weighted_mean_udf
+
+ # Groupby one column and aggregate one UDF with literal
+ result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id')
+ expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ # Groupby one expression and aggregate one UDF with literal
+ result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\
+ .sort(df.id + 1)
+ expected2 = df.groupby((col('id') + 1))\
+ .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1)
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+ # Groupby one column and aggregate one UDF without literal
+ result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id')
+ expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id')
+ self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+
+ # Groupby one expression and aggregate one UDF without literal
+ result4 = df.groupby((col('id') + 1).alias('id'))\
+ .agg(weighted_mean_udf(df.v, df.w))\
+ .sort('id')
+ expected4 = df.groupby((col('id') + 1).alias('id'))\
+ .agg(mean(df.v).alias('weighted_mean(v, w)'))\
+ .sort('id')
+ self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+
+ def test_unsupported_types(self):
+ from pyspark.sql.types import DoubleType, MapType
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
+ pandas_udf(
+ lambda x: x,
+ ArrayType(ArrayType(TimestampType())),
+ PandasUDFType.GROUPED_AGG)
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
+ @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG)
+ def mean_and_std_udf(v):
+ return v.mean(), v.std()
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
+ @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG)
+ def mean_and_std_udf(v):
+ return {v.mean(): v.std()}
+
+ def test_alias(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias'))
+ expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias'))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_mixed_sql(self):
+ """
+ Test mixing group aggregate pandas UDF with sql expression.
+ """
+ from pyspark.sql.functions import sum
+
+ df = self.data
+ sum_udf = self.pandas_agg_sum_udf
+
+ # Mix group aggregate pandas UDF with sql expression
+ result1 = (df.groupby('id')
+ .agg(sum_udf(df.v) + 1)
+ .sort('id'))
+ expected1 = (df.groupby('id')
+ .agg(sum(df.v) + 1)
+ .sort('id'))
+
+ # Mix group aggregate pandas UDF with sql expression (order swapped)
+ result2 = (df.groupby('id')
+ .agg(sum_udf(df.v + 1))
+ .sort('id'))
+
+ expected2 = (df.groupby('id')
+ .agg(sum(df.v + 1))
+ .sort('id'))
+
+ # Wrap group aggregate pandas UDF with two sql expressions
+ result3 = (df.groupby('id')
+ .agg(sum_udf(df.v + 1) + 2)
+ .sort('id'))
+ expected3 = (df.groupby('id')
+ .agg(sum(df.v + 1) + 2)
+ .sort('id'))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+ self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+
+ def test_mixed_udfs(self):
+ """
+ Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF.
+ """
+ from pyspark.sql.functions import sum
+
+ df = self.data
+ plus_one = self.python_plus_one
+ plus_two = self.pandas_scalar_plus_two
+ sum_udf = self.pandas_agg_sum_udf
+
+ # Mix group aggregate pandas UDF and python UDF
+ result1 = (df.groupby('id')
+ .agg(plus_one(sum_udf(df.v)))
+ .sort('id'))
+ expected1 = (df.groupby('id')
+ .agg(plus_one(sum(df.v)))
+ .sort('id'))
+
+ # Mix group aggregate pandas UDF and python UDF (order swapped)
+ result2 = (df.groupby('id')
+ .agg(sum_udf(plus_one(df.v)))
+ .sort('id'))
+ expected2 = (df.groupby('id')
+ .agg(sum(plus_one(df.v)))
+ .sort('id'))
+
+ # Mix group aggregate pandas UDF and scalar pandas UDF
+ result3 = (df.groupby('id')
+ .agg(sum_udf(plus_two(df.v)))
+ .sort('id'))
+ expected3 = (df.groupby('id')
+ .agg(sum(plus_two(df.v)))
+ .sort('id'))
+
+ # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped)
+ result4 = (df.groupby('id')
+ .agg(plus_two(sum_udf(df.v)))
+ .sort('id'))
+ expected4 = (df.groupby('id')
+ .agg(plus_two(sum(df.v)))
+ .sort('id'))
+
+ # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby
+ result5 = (df.groupby(plus_one(df.id))
+ .agg(plus_one(sum_udf(plus_one(df.v))))
+ .sort('plus_one(id)'))
+ expected5 = (df.groupby(plus_one(df.id))
+ .agg(plus_one(sum(plus_one(df.v))))
+ .sort('plus_one(id)'))
+
+ # Wrap group aggregate pandas UDF with two scalar pandas UDFs and use scalar pandas UDF in
+ # groupby
+ result6 = (df.groupby(plus_two(df.id))
+ .agg(plus_two(sum_udf(plus_two(df.v))))
+ .sort('plus_two(id)'))
+ expected6 = (df.groupby(plus_two(df.id))
+ .agg(plus_two(sum(plus_two(df.v))))
+ .sort('plus_two(id)'))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+ self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+ self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+ self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
+ self.assertPandasEqual(expected6.toPandas(), result6.toPandas())
+
+ def test_multiple_udfs(self):
+ """
+ Test multiple group aggregate pandas UDFs in one agg function.
+ """
+ from pyspark.sql.functions import sum, mean
+
+ df = self.data
+ mean_udf = self.pandas_agg_mean_udf
+ sum_udf = self.pandas_agg_sum_udf
+ weighted_mean_udf = self.pandas_agg_weighted_mean_udf
+
+ result1 = (df.groupBy('id')
+ .agg(mean_udf(df.v),
+ sum_udf(df.v),
+ weighted_mean_udf(df.v, df.w))
+ .sort('id')
+ .toPandas())
+ expected1 = (df.groupBy('id')
+ .agg(mean(df.v),
+ sum(df.v),
+ mean(df.v).alias('weighted_mean(v, w)'))
+ .sort('id')
+ .toPandas())
+
+ self.assertPandasEqual(expected1, result1)
+
+ def test_complex_groupby(self):
+ from pyspark.sql.functions import sum
+
+ df = self.data
+ sum_udf = self.pandas_agg_sum_udf
+ plus_one = self.python_plus_one
+ plus_two = self.pandas_scalar_plus_two
+
+ # groupby one expression
+ result1 = df.groupby(df.v % 2).agg(sum_udf(df.v))
+ expected1 = df.groupby(df.v % 2).agg(sum(df.v))
+
+ # empty groupby
+ result2 = df.groupby().agg(sum_udf(df.v))
+ expected2 = df.groupby().agg(sum(df.v))
+
+ # groupby one column and one sql expression
+ result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2)
+ expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2)
+
+ # groupby one python UDF
+ result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v))
+ expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v))
+
+ # groupby one scalar pandas UDF
+ result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v))
+ expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v))
+
+ # groupby one expression and one python UDF
+ result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v))
+ expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v))
+
+ # groupby one expression and one scalar pandas UDF
+ result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)')
+ expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)')
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+ self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+ self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+ self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
+ self.assertPandasEqual(expected6.toPandas(), result6.toPandas())
+ self.assertPandasEqual(expected7.toPandas(), result7.toPandas())
+
+ def test_complex_expressions(self):
+ from pyspark.sql.functions import col, sum
+
+ df = self.data
+ plus_one = self.python_plus_one
+ plus_two = self.pandas_scalar_plus_two
+ sum_udf = self.pandas_agg_sum_udf
+
+ # Test complex expressions with sql expression, python UDF and
+ # group aggregate pandas UDF
+ result1 = (df.withColumn('v1', plus_one(df.v))
+ .withColumn('v2', df.v + 2)
+ .groupby(df.id, df.v % 2)
+ .agg(sum_udf(col('v')),
+ sum_udf(col('v1') + 3),
+ sum_udf(col('v2')) + 5,
+ plus_one(sum_udf(col('v1'))),
+ sum_udf(plus_one(col('v2'))))
+ .sort('id')
+ .toPandas())
+
+ expected1 = (df.withColumn('v1', df.v + 1)
+ .withColumn('v2', df.v + 2)
+ .groupby(df.id, df.v % 2)
+ .agg(sum(col('v')),
+ sum(col('v1') + 3),
+ sum(col('v2')) + 5,
+ plus_one(sum(col('v1'))),
+ sum(plus_one(col('v2'))))
+ .sort('id')
+ .toPandas())
+
+ # Test complex expressions with sql expression, scalar pandas UDF and
+ # group aggregate pandas UDF
+ result2 = (df.withColumn('v1', plus_one(df.v))
+ .withColumn('v2', df.v + 2)
+ .groupby(df.id, df.v % 2)
+ .agg(sum_udf(col('v')),
+ sum_udf(col('v1') + 3),
+ sum_udf(col('v2')) + 5,
+ plus_two(sum_udf(col('v1'))),
+ sum_udf(plus_two(col('v2'))))
+ .sort('id')
+ .toPandas())
+
+ expected2 = (df.withColumn('v1', df.v + 1)
+ .withColumn('v2', df.v + 2)
+ .groupby(df.id, df.v % 2)
+ .agg(sum(col('v')),
+ sum(col('v1') + 3),
+ sum(col('v2')) + 5,
+ plus_two(sum(col('v1'))),
+ sum(plus_two(col('v2'))))
+ .sort('id')
+ .toPandas())
+
+ # Test sequential groupby aggregate
+ result3 = (df.groupby('id')
+ .agg(sum_udf(df.v).alias('v'))
+ .groupby('id')
+ .agg(sum_udf(col('v')))
+ .sort('id')
+ .toPandas())
+
+ expected3 = (df.groupby('id')
+ .agg(sum(df.v).alias('v'))
+ .groupby('id')
+ .agg(sum(col('v')))
+ .sort('id')
+ .toPandas())
+
+ self.assertPandasEqual(expected1, result1)
+ self.assertPandasEqual(expected2, result2)
+ self.assertPandasEqual(expected3, result3)
+
+ def test_retain_group_columns(self):
+ from pyspark.sql.functions import sum
+ with self.sql_conf({"spark.sql.retainGroupColumns": False}):
+ df = self.data
+ sum_udf = self.pandas_agg_sum_udf
+
+ result1 = df.groupby(df.id).agg(sum_udf(df.v))
+ expected1 = df.groupby(df.id).agg(sum(df.v))
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_array_type(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ df = self.data
+
+ array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
+ result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
+ self.assertEquals(result1.first()['v2'], [1.0, 2.0])
+
+ def test_invalid_args(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ plus_one = self.python_plus_one
+ mean_udf = self.pandas_agg_mean_udf
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ AnalysisException,
+ 'nor.*aggregate function'):
+ df.groupby(df.id).agg(plus_one(df.v)).collect()
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ AnalysisException,
+ 'aggregate function.*argument.*aggregate function'):
+ df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ AnalysisException,
+ 'mixture.*aggregate function.*group aggregate pandas UDF'):
+ df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
+
+ def test_register_vectorized_udf_basic(self):
+ from pyspark.sql.functions import pandas_udf
+ from pyspark.rdd import PythonEvalType
+
+ sum_pandas_udf = pandas_udf(
+ lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+
+ self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+ group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf)
+ self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+ q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
+ actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
+ expected = [1, 5]
+ self.assertEqual(actual, expected)
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_pandas_udf_grouped_agg import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
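For context, the core pattern covered by GroupedAggPandasUDFTests reduces each group's pandas.Series to a single value inside `agg`. A minimal sketch (not part of the patch, assuming a local SparkSession named `spark`) looks like this:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    spark = SparkSession.builder.master("local[1]").getOrCreate()  # assumed local session

    # Group aggregate pandas UDF: reduces the pandas.Series for each group to one scalar.
    @pandas_udf('double', PandasUDFType.GROUPED_AGG)
    def mean_udf(v):
        return v.mean()

    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ['id', 'v'])
    df.groupby('id').agg(mean_udf(df.v).alias('mean_v')).show()

    # The same UDF can also be registered for SQL, mirroring test_register_vectorized_udf_basic.
    spark.udf.register('mean_udf', mean_udf)
    spark.sql("SELECT id, mean_udf(v) FROM VALUES (1, 1.0), (1, 2.0), (2, 3.0) tbl(id, v) "
              "GROUP BY id").show()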
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
new file mode 100644
index 0000000000000..bfecc071386e9
--- /dev/null
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -0,0 +1,531 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import unittest
+
+from pyspark.sql import Row
+from pyspark.sql.types import *
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
+ pandas_requirement_message, pyarrow_requirement_message
+from pyspark.testing.utils import QuietTest
+
+
+@unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message)
+class GroupedMapPandasUDFTests(ReusedSQLTestCase):
+
+ @property
+ def data(self):
+ from pyspark.sql.functions import array, explode, col, lit
+ return self.spark.range(10).toDF('id') \
+ .withColumn("vs", array([lit(i) for i in range(20, 30)])) \
+ .withColumn("v", explode(col('vs'))).drop('vs')
+
+ def test_supported_types(self):
+ from decimal import Decimal
+ from distutils.version import LooseVersion
+ import pyarrow as pa
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ values = [
+ 1, 2, 3,
+ 4, 5, 1.1,
+ 2.2, Decimal(1.123),
+ [1, 2, 2], True, 'hello'
+ ]
+ output_fields = [
+ ('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()),
+ ('int', IntegerType()), ('long', LongType()), ('float', FloatType()),
+ ('double', DoubleType()), ('decim', DecimalType(10, 3)),
+ ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType())
+ ]
+
+ # TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0
+ if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"):
+ values.append(bytearray([0x01, 0x02]))
+ output_fields.append(('bin', BinaryType()))
+
+ output_schema = StructType([StructField(*x) for x in output_fields])
+ df = self.spark.createDataFrame([values], schema=output_schema)
+
+ # Different forms of group map pandas UDF, results of these are the same
+ udf1 = pandas_udf(
+ lambda pdf: pdf.assign(
+ byte=pdf.byte * 2,
+ short=pdf.short * 2,
+ int=pdf.int * 2,
+ long=pdf.long * 2,
+ float=pdf.float * 2,
+ double=pdf.double * 2,
+ decim=pdf.decim * 2,
+ bool=False if pdf.bool else True,
+ str=pdf.str + 'there',
+ array=pdf.array,
+ ),
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ udf2 = pandas_udf(
+ lambda _, pdf: pdf.assign(
+ byte=pdf.byte * 2,
+ short=pdf.short * 2,
+ int=pdf.int * 2,
+ long=pdf.long * 2,
+ float=pdf.float * 2,
+ double=pdf.double * 2,
+ decim=pdf.decim * 2,
+ bool=False if pdf.bool else True,
+ str=pdf.str + 'there',
+ array=pdf.array,
+ ),
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ udf3 = pandas_udf(
+ lambda key, pdf: pdf.assign(
+ id=key[0],
+ byte=pdf.byte * 2,
+ short=pdf.short * 2,
+ int=pdf.int * 2,
+ long=pdf.long * 2,
+ float=pdf.float * 2,
+ double=pdf.double * 2,
+ decim=pdf.decim * 2,
+ bool=False if pdf.bool else True,
+ str=pdf.str + 'there',
+ array=pdf.array,
+ ),
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
+ expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
+
+ result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
+ expected2 = expected1
+
+ result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
+ expected3 = expected1
+
+ self.assertPandasEqual(expected1, result1)
+ self.assertPandasEqual(expected2, result2)
+ self.assertPandasEqual(expected3, result3)
+
+ def test_array_type_correct(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
+
+ df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")
+
+ output_schema = StructType(
+ [StructField('id', LongType()),
+ StructField('v', IntegerType()),
+ StructField('arr', ArrayType(LongType()))])
+
+ udf = pandas_udf(
+ lambda pdf: pdf,
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ result = df.groupby('id').apply(udf).sort('id').toPandas()
+ expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True)
+ self.assertPandasEqual(expected, result)
+
+ def test_register_grouped_map_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'):
+ self.spark.catalog.registerFunction("foo_udf", foo_udf)
+
+ def test_decorator(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+ df = self.data
+
+ @pandas_udf(
+ 'id long, v int, v1 double, v2 long',
+ PandasUDFType.GROUPED_MAP
+ )
+ def foo(pdf):
+ return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id)
+
+ result = df.groupby('id').apply(foo).sort('id').toPandas()
+ expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
+ self.assertPandasEqual(expected, result)
+
+ def test_coerce(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+ df = self.data
+
+ foo = pandas_udf(
+ lambda pdf: pdf,
+ 'id long, v double',
+ PandasUDFType.GROUPED_MAP
+ )
+
+ result = df.groupby('id').apply(foo).sort('id').toPandas()
+ expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
+ expected = expected.assign(v=expected.v.astype('float64'))
+ self.assertPandasEqual(expected, result)
+
+ def test_complex_groupby(self):
+ from pyspark.sql.functions import pandas_udf, col, PandasUDFType
+ df = self.data
+
+ @pandas_udf(
+ 'id long, v int, norm double',
+ PandasUDFType.GROUPED_MAP
+ )
+ def normalize(pdf):
+ v = pdf.v
+ return pdf.assign(norm=(v - v.mean()) / v.std())
+
+ result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas()
+ pdf = df.toPandas()
+ expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func)
+ expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
+ expected = expected.assign(norm=expected.norm.astype('float64'))
+ self.assertPandasEqual(expected, result)
+
+ def test_empty_groupby(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+ df = self.data
+
+ @pandas_udf(
+ 'id long, v int, norm double',
+ PandasUDFType.GROUPED_MAP
+ )
+ def normalize(pdf):
+ v = pdf.v
+ return pdf.assign(norm=(v - v.mean()) / v.std())
+
+ result = df.groupby().apply(normalize).sort('id', 'v').toPandas()
+ pdf = df.toPandas()
+ expected = normalize.func(pdf)
+ expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
+ expected = expected.assign(norm=expected.norm.astype('float64'))
+ self.assertPandasEqual(expected, result)
+
+ def test_datatype_string(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+ df = self.data
+
+ foo_udf = pandas_udf(
+ lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+ 'id long, v int, v1 double, v2 long',
+ PandasUDFType.GROUPED_MAP
+ )
+
+ result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
+ expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
+ self.assertPandasEqual(expected, result)
+
+ def test_wrong_return_type(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*grouped map Pandas UDF.*MapType'):
+ pandas_udf(
+ lambda pdf: pdf,
+ 'id long, v map<int, int>',
+ PandasUDFType.GROUPED_MAP)
+
+ def test_wrong_args(self):
+ from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType
+ df = self.data
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
+ df.groupby('id').apply(lambda x: x)
+ with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
+ df.groupby('id').apply(udf(lambda x: x, DoubleType()))
+ with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
+ df.groupby('id').apply(sum(df.v))
+ with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
+ df.groupby('id').apply(df.v + 1)
+ with self.assertRaisesRegexp(ValueError, 'Invalid function'):
+ df.groupby('id').apply(
+ pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
+ with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
+ df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType()))
+ with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
+ df.groupby('id').apply(
+ pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
+
+ def test_unsupported_types(self):
+ from distutils.version import LooseVersion
+ import pyarrow as pa
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*'
+ unsupported_types = [
+ StructField('map', MapType(StringType(), IntegerType())),
+ StructField('arr_ts', ArrayType(TimestampType())),
+ StructField('null', NullType()),
+ ]
+
+ # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
+ if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+ unsupported_types.append(StructField('bin', BinaryType()))
+
+ for unsupported_type in unsupported_types:
+ schema = StructType([StructField('id', LongType(), True), unsupported_type])
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(NotImplementedError, common_err_msg):
+ pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
+
+ # Regression test for SPARK-23314
+ def test_timestamp_dst(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+ # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
+ dt = [datetime.datetime(2015, 11, 1, 0, 30),
+ datetime.datetime(2015, 11, 1, 1, 30),
+ datetime.datetime(2015, 11, 1, 2, 30)]
+ df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
+ foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP)
+ result = df.groupby('time').apply(foo_udf).sort('time')
+ self.assertPandasEqual(df.toPandas(), result.toPandas())
+
+ def test_udf_with_key(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+ df = self.data
+ pdf = df.toPandas()
+
+ def foo1(key, pdf):
+ import numpy as np
+ assert type(key) == tuple
+ assert type(key[0]) == np.int64
+
+ return pdf.assign(v1=key[0],
+ v2=pdf.v * key[0],
+ v3=pdf.v * pdf.id,
+ v4=pdf.v * pdf.id.mean())
+
+ def foo2(key, pdf):
+ import numpy as np
+ assert type(key) == tuple
+ assert type(key[0]) == np.int64
+ assert type(key[1]) == np.int32
+
+ return pdf.assign(v1=key[0],
+ v2=key[1],
+ v3=pdf.v * key[0],
+ v4=pdf.v + key[1])
+
+ def foo3(key, pdf):
+ assert type(key) == tuple
+ assert len(key) == 0
+ return pdf.assign(v1=pdf.v * pdf.id)
+
+ # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
+ # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
+ udf1 = pandas_udf(
+ foo1,
+ 'id long, v int, v1 long, v2 int, v3 long, v4 double',
+ PandasUDFType.GROUPED_MAP)
+
+ udf2 = pandas_udf(
+ foo2,
+ 'id long, v int, v1 long, v2 int, v3 int, v4 int',
+ PandasUDFType.GROUPED_MAP)
+
+ udf3 = pandas_udf(
+ foo3,
+ 'id long, v int, v1 long',
+ PandasUDFType.GROUPED_MAP)
+
+ # Test groupby column
+ result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
+ expected1 = pdf.groupby('id')\
+ .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected1, result1)
+
+ # Test groupby expression
+ result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
+ expected2 = pdf.groupby(pdf.id % 2)\
+ .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected2, result2)
+
+ # Test complex groupby
+ result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
+ expected3 = pdf.groupby([pdf.id, pdf.v % 2])\
+ .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected3, result3)
+
+ # Test empty groupby
+ result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
+ expected4 = udf3.func((), pdf)
+ self.assertPandasEqual(expected4, result4)
+
+ def test_column_order(self):
+ from collections import OrderedDict
+ import pandas as pd
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ # Helper function to set column names from a list
+ def rename_pdf(pdf, names):
+ pdf.rename(columns={old: new for old, new in
+ zip(pd_result.columns, names)}, inplace=True)
+
+ df = self.data
+ grouped_df = df.groupby('id')
+ grouped_pdf = df.toPandas().groupby('id')
+
+ # Function returns a pdf with required column names, but order could be arbitrary using dict
+ def change_col_order(pdf):
+ # Constructing a DataFrame from a dict should result in the same order,
+ # but use from_items to ensure the pdf column order is different than schema
+ return pd.DataFrame.from_items([
+ ('id', pdf.id),
+ ('u', pdf.v * 2),
+ ('v', pdf.v)])
+
+ ordered_udf = pandas_udf(
+ change_col_order,
+ 'id long, v int, u int',
+ PandasUDFType.GROUPED_MAP
+ )
+
+ # The UDF result should assign columns by name from the pdf
+ result = grouped_df.apply(ordered_udf).sort('id', 'v')\
+ .select('id', 'u', 'v').toPandas()
+ pd_result = grouped_pdf.apply(change_col_order)
+ expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected, result)
+
+ # Function returns a pdf with positional columns, indexed by range
+ def range_col_order(pdf):
+ # Create a DataFrame with positional columns, fix types to long
+ return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64')
+
+ range_udf = pandas_udf(
+ range_col_order,
+ 'id long, u long, v long',
+ PandasUDFType.GROUPED_MAP
+ )
+
+ # The UDF result uses positional columns from the pdf
+ result = grouped_df.apply(range_udf).sort('id', 'v') \
+ .select('id', 'u', 'v').toPandas()
+ pd_result = grouped_pdf.apply(range_col_order)
+ rename_pdf(pd_result, ['id', 'u', 'v'])
+ expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected, result)
+
+ # Function returns a pdf with columns indexed with integers
+ def int_index(pdf):
+ return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)]))
+
+ int_index_udf = pandas_udf(
+ int_index,
+ 'id long, u int, v int',
+ PandasUDFType.GROUPED_MAP
+ )
+
+ # The UDF result should assign columns by position of integer index
+ result = grouped_df.apply(int_index_udf).sort('id', 'v') \
+ .select('id', 'u', 'v').toPandas()
+ pd_result = grouped_pdf.apply(int_index)
+ rename_pdf(pd_result, ['id', 'u', 'v'])
+ expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected, result)
+
+ @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
+ def column_name_typo(pdf):
+ return pd.DataFrame({'iid': pdf.id, 'v': pdf.v})
+
+ @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
+ def invalid_positional_types(pdf):
+ return pd.DataFrame([(u'a', 1.2)])
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
+ grouped_df.apply(column_name_typo).collect()
+ with self.assertRaisesRegexp(Exception, "No cast implemented"):
+ grouped_df.apply(invalid_positional_types).collect()
+
+ def test_positional_assignment_conf(self):
+ import pandas as pd
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ with self.sql_conf({
+ "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}):
+
+ @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP)
+ def foo(_):
+ return pd.DataFrame([('hi', 1)], columns=['x', 'y'])
+
+ df = self.data
+ result = df.groupBy('id').apply(foo).select('a', 'b').collect()
+ for r in result:
+ self.assertEqual(r.a, 'hi')
+ self.assertEqual(r.b, 1)
+
+ def test_self_join_with_pandas(self):
+ import pyspark.sql.functions as F
+
+ @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
+ def dummy_pandas_udf(df):
+ return df[['key', 'col']]
+
+ df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
+ Row(key=2, col='C')])
+ df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf)
+
+ # this was throwing an AnalysisException before SPARK-24208
+ res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
+ F.col('temp0.key') == F.col('temp1.key'))
+ self.assertEquals(res.count(), 5)
+
+ def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
+ import pandas as pd
+ from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+
+ df = self.spark.range(0, 10).toDF('v1')
+ df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
+ .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
+
+ result = df.groupby() \
+ .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
+ 'sum int',
+ PandasUDFType.GROUPED_MAP))
+
+ self.assertEquals(result.collect()[0]['sum'], 165)
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_pandas_udf_grouped_map import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
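The grouped map tests above all revolve around one pattern: each group is handed to the UDF as a pandas.DataFrame, and the returned frame must match the declared schema (matched by column name unless the legacy positional-assignment conf is enabled). A minimal sketch, not part of the patch and assuming a local SparkSession named `spark`:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    spark = SparkSession.builder.master("local[1]").getOrCreate()  # assumed local session

    # Each group's rows arrive as one pandas.DataFrame; normalize 'v' within the group.
    @pandas_udf('id long, v double, norm double', PandasUDFType.GROUPED_MAP)
    def normalize(pdf):
        v = pdf.v
        return pdf.assign(norm=(v - v.mean()) / v.std())

    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ['id', 'v'])
    df.groupby('id').apply(normalize).show()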
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
new file mode 100644
index 0000000000000..b303398850394
--- /dev/null
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -0,0 +1,808 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import datetime
+import os
+import shutil
+import sys
+import tempfile
+import time
+import unittest
+
+from pyspark.sql.types import Row
+from pyspark.sql.types import *
+from pyspark.sql.utils import AnalysisException
+from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled,\
+ test_not_compiled_message, have_pandas, have_pyarrow, pandas_requirement_message, \
+ pyarrow_requirement_message
+from pyspark.testing.utils import QuietTest
+
+
+@unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message)
+class ScalarPandasUDFTests(ReusedSQLTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedSQLTestCase.setUpClass()
+
+ # Synchronize default timezone between Python and Java
+ cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
+ tz = "UTC"
+ os.environ["TZ"] = tz
+ time.tzset()
+
+ cls.sc.environment["TZ"] = tz
+ cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+ @classmethod
+ def tearDownClass(cls):
+ del os.environ["TZ"]
+ if cls.tz_prev is not None:
+ os.environ["TZ"] = cls.tz_prev
+ time.tzset()
+ ReusedSQLTestCase.tearDownClass()
+
+ @property
+ def nondeterministic_vectorized_udf(self):
+ from pyspark.sql.functions import pandas_udf
+
+ @pandas_udf('double')
+ def random_udf(v):
+ import pandas as pd
+ import numpy as np
+ return pd.Series(np.random.random(len(v)))
+ random_udf = random_udf.asNondeterministic()
+ return random_udf
+
+ def test_pandas_udf_tokenize(self):
+ from pyspark.sql.functions import pandas_udf
+ tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')),
+ ArrayType(StringType()))
+ self.assertEqual(tokenize.returnType, ArrayType(StringType()))
+ df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
+ result = df.select(tokenize("vals").alias("hi"))
+ self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect())
+
+ def test_pandas_udf_nested_arrays(self):
+ from pyspark.sql.functions import pandas_udf
+ tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]),
+ ArrayType(ArrayType(StringType())))
+ self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType())))
+ df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
+ result = df.select(tokenize("vals").alias("hi"))
+ self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect())
+
+ def test_vectorized_udf_basic(self):
+ from pyspark.sql.functions import pandas_udf, col, array
+ df = self.spark.range(10).select(
+ col('id').cast('string').alias('str'),
+ col('id').cast('int').alias('int'),
+ col('id').alias('long'),
+ col('id').cast('float').alias('float'),
+ col('id').cast('double').alias('double'),
+ col('id').cast('decimal').alias('decimal'),
+ col('id').cast('boolean').alias('bool'),
+ array(col('id')).alias('array_long'))
+ f = lambda x: x
+ str_f = pandas_udf(f, StringType())
+ int_f = pandas_udf(f, IntegerType())
+ long_f = pandas_udf(f, LongType())
+ float_f = pandas_udf(f, FloatType())
+ double_f = pandas_udf(f, DoubleType())
+ decimal_f = pandas_udf(f, DecimalType())
+ bool_f = pandas_udf(f, BooleanType())
+ array_long_f = pandas_udf(f, ArrayType(LongType()))
+ res = df.select(str_f(col('str')), int_f(col('int')),
+ long_f(col('long')), float_f(col('float')),
+ double_f(col('double')), decimal_f('decimal'),
+ bool_f(col('bool')), array_long_f('array_long'))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_register_nondeterministic_vectorized_udf_basic(self):
+ from pyspark.sql.functions import pandas_udf
+ from pyspark.rdd import PythonEvalType
+ import random
+ random_pandas_udf = pandas_udf(
+ lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic()
+ self.assertEqual(random_pandas_udf.deterministic, False)
+ self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+ nondeterministic_pandas_udf = self.spark.catalog.registerFunction(
+ "randomPandasUDF", random_pandas_udf)
+ self.assertEqual(nondeterministic_pandas_udf.deterministic, False)
+ self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+ [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect()
+ self.assertEqual(row[0], 7)
+
+ def test_vectorized_udf_null_boolean(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(True,), (True,), (None,), (False,)]
+ schema = StructType().add("bool", BooleanType())
+ df = self.spark.createDataFrame(data, schema)
+ bool_f = pandas_udf(lambda x: x, BooleanType())
+ res = df.select(bool_f(col('bool')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_byte(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("byte", ByteType())
+ df = self.spark.createDataFrame(data, schema)
+ byte_f = pandas_udf(lambda x: x, ByteType())
+ res = df.select(byte_f(col('byte')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_short(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("short", ShortType())
+ df = self.spark.createDataFrame(data, schema)
+ short_f = pandas_udf(lambda x: x, ShortType())
+ res = df.select(short_f(col('short')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_int(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("int", IntegerType())
+ df = self.spark.createDataFrame(data, schema)
+ int_f = pandas_udf(lambda x: x, IntegerType())
+ res = df.select(int_f(col('int')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_long(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("long", LongType())
+ df = self.spark.createDataFrame(data, schema)
+ long_f = pandas_udf(lambda x: x, LongType())
+ res = df.select(long_f(col('long')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_float(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(3.0,), (5.0,), (-1.0,), (None,)]
+ schema = StructType().add("float", FloatType())
+ df = self.spark.createDataFrame(data, schema)
+ float_f = pandas_udf(lambda x: x, FloatType())
+ res = df.select(float_f(col('float')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_double(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(3.0,), (5.0,), (-1.0,), (None,)]
+ schema = StructType().add("double", DoubleType())
+ df = self.spark.createDataFrame(data, schema)
+ double_f = pandas_udf(lambda x: x, DoubleType())
+ res = df.select(double_f(col('double')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_decimal(self):
+ from decimal import Decimal
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
+ schema = StructType().add("decimal", DecimalType(38, 18))
+ df = self.spark.createDataFrame(data, schema)
+ decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18))
+ res = df.select(decimal_f(col('decimal')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_string(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [("foo",), (None,), ("bar",), ("bar",)]
+ schema = StructType().add("str", StringType())
+ df = self.spark.createDataFrame(data, schema)
+ str_f = pandas_udf(lambda x: x, StringType())
+ res = df.select(str_f(col('str')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_string_in_udf(self):
+ from pyspark.sql.functions import pandas_udf, col
+ import pandas as pd
+ df = self.spark.range(10)
+ str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType())
+ actual = df.select(str_f(col('id')))
+ expected = df.select(col('id').cast('string'))
+ self.assertEquals(expected.collect(), actual.collect())
+
+ def test_vectorized_udf_datatype_string(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10).select(
+ col('id').cast('string').alias('str'),
+ col('id').cast('int').alias('int'),
+ col('id').alias('long'),
+ col('id').cast('float').alias('float'),
+ col('id').cast('double').alias('double'),
+ col('id').cast('decimal').alias('decimal'),
+ col('id').cast('boolean').alias('bool'))
+ f = lambda x: x
+ str_f = pandas_udf(f, 'string')
+ int_f = pandas_udf(f, 'integer')
+ long_f = pandas_udf(f, 'long')
+ float_f = pandas_udf(f, 'float')
+ double_f = pandas_udf(f, 'double')
+ decimal_f = pandas_udf(f, 'decimal(38, 18)')
+ bool_f = pandas_udf(f, 'boolean')
+ res = df.select(str_f(col('str')), int_f(col('int')),
+ long_f(col('long')), float_f(col('float')),
+ double_f(col('double')), decimal_f('decimal'),
+ bool_f(col('bool')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_binary(self):
+ from distutils.version import LooseVersion
+ import pyarrow as pa
+ from pyspark.sql.functions import pandas_udf, col
+ if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*scalar Pandas UDF.*BinaryType'):
+ pandas_udf(lambda x: x, BinaryType())
+ else:
+ data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)]
+ schema = StructType().add("binary", BinaryType())
+ df = self.spark.createDataFrame(data, schema)
+ str_f = pandas_udf(lambda x: x, BinaryType())
+ res = df.select(str_f(col('binary')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_array_type(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [([1, 2],), ([3, 4],)]
+ array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
+ df = self.spark.createDataFrame(data, schema=array_schema)
+ array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()))
+ result = df.select(array_f(col('array')))
+ self.assertEquals(df.collect(), result.collect())
+
+ def test_vectorized_udf_null_array(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
+ array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
+ df = self.spark.createDataFrame(data, schema=array_schema)
+ array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()))
+ result = df.select(array_f(col('array')))
+ self.assertEquals(df.collect(), result.collect())
+
+ def test_vectorized_udf_complex(self):
+ from pyspark.sql.functions import pandas_udf, col, expr
+ df = self.spark.range(10).select(
+ col('id').cast('int').alias('a'),
+ col('id').cast('int').alias('b'),
+ col('id').cast('double').alias('c'))
+ add = pandas_udf(lambda x, y: x + y, IntegerType())
+ power2 = pandas_udf(lambda x: 2 ** x, IntegerType())
+ mul = pandas_udf(lambda x, y: x * y, DoubleType())
+ res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
+ expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
+ self.assertEquals(expected.collect(), res.collect())
+
+ def test_vectorized_udf_exception(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10)
+ raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
+ df.select(raise_exception(col('id'))).collect()
+
+ def test_vectorized_udf_invalid_length(self):
+ from pyspark.sql.functions import pandas_udf, col
+ import pandas as pd
+ df = self.spark.range(10)
+ raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ Exception,
+ 'Result vector from pandas_udf was not the required length'):
+ df.select(raise_exception(col('id'))).collect()
+
+ def test_vectorized_udf_chained(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10)
+ f = pandas_udf(lambda x: x + 1, LongType())
+ g = pandas_udf(lambda x: x - 1, LongType())
+ res = df.select(g(f(col('id'))))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_wrong_return_type(self):
+ from pyspark.sql.functions import pandas_udf
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*scalar Pandas UDF.*MapType'):
+ pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
+
+ def test_vectorized_udf_return_scalar(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10)
+ f = pandas_udf(lambda x: 1.0, DoubleType())
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'):
+ df.select(f(col('id'))).collect()
+
+ def test_vectorized_udf_decorator(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10)
+
+ @pandas_udf(returnType=LongType())
+ def identity(x):
+ return x
+ res = df.select(identity(col('id')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_empty_partition(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
+ f = pandas_udf(lambda x: x, LongType())
+ res = df.select(f(col('id')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_varargs(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
+ f = pandas_udf(lambda *v: v[0], LongType())
+ res = df.select(f(col('id')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_unsupported_types(self):
+ from pyspark.sql.functions import pandas_udf
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*scalar Pandas UDF.*MapType'):
+ pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
+
+ def test_vectorized_udf_dates(self):
+ from pyspark.sql.functions import pandas_udf, col
+ from datetime import date
+ schema = StructType().add("idx", LongType()).add("date", DateType())
+ data = [(0, date(1969, 1, 1),),
+ (1, date(2012, 2, 2),),
+ (2, None,),
+ (3, date(2100, 4, 4),)]
+ df = self.spark.createDataFrame(data, schema=schema)
+
+ date_copy = pandas_udf(lambda t: t, returnType=DateType())
+ df = df.withColumn("date_copy", date_copy(col("date")))
+
+ @pandas_udf(returnType=StringType())
+ def check_data(idx, date, date_copy):
+ import pandas as pd
+ msgs = []
+ is_equal = date.isnull()
+ for i in range(len(idx)):
+ if (is_equal[i] and data[idx[i]][1] is None) or \
+ date[i] == data[idx[i]][1]:
+ msgs.append(None)
+ else:
+ msgs.append(
+ "date values are not equal (date='%s': data[%d][1]='%s')"
+ % (date[i], idx[i], data[idx[i]][1]))
+ return pd.Series(msgs)
+
+ result = df.withColumn("check_data",
+ check_data(col("idx"), col("date"), col("date_copy"))).collect()
+
+ self.assertEquals(len(data), len(result))
+ for i in range(len(result)):
+ self.assertEquals(data[i][1], result[i][1]) # "date" col
+ self.assertEquals(data[i][1], result[i][2]) # "date_copy" col
+ self.assertIsNone(result[i][3]) # "check_data" col
+
+ def test_vectorized_udf_timestamps(self):
+ from pyspark.sql.functions import pandas_udf, col
+ from datetime import datetime
+ schema = StructType([
+ StructField("idx", LongType(), True),
+ StructField("timestamp", TimestampType(), True)])
+ data = [(0, datetime(1969, 1, 1, 1, 1, 1)),
+ (1, datetime(2012, 2, 2, 2, 2, 2)),
+ (2, None),
+ (3, datetime(2100, 3, 3, 3, 3, 3))]
+
+ df = self.spark.createDataFrame(data, schema=schema)
+
+ # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc
+ f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType())
+ df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))
+
+ @pandas_udf(returnType=StringType())
+ def check_data(idx, timestamp, timestamp_copy):
+ import pandas as pd
+ msgs = []
+ is_equal = timestamp.isnull() # use this array to check values are equal
+ for i in range(len(idx)):
+ # Check that timestamps are as expected in the UDF
+ if (is_equal[i] and data[idx[i]][1] is None) or \
+ timestamp[i].to_pydatetime() == data[idx[i]][1]:
+ msgs.append(None)
+ else:
+ msgs.append(
+ "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')"
+ % (timestamp[i], idx[i], data[idx[i]][1]))
+ return pd.Series(msgs)
+
+ result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"),
+ col("timestamp_copy"))).collect()
+ # Check that collection values are correct
+ self.assertEquals(len(data), len(result))
+ for i in range(len(result)):
+ self.assertEquals(data[i][1], result[i][1]) # "timestamp" col
+ self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col
+ self.assertIsNone(result[i][3]) # "check_data" col
+
+ def test_vectorized_udf_return_timestamp_tz(self):
+ from pyspark.sql.functions import pandas_udf, col
+ import pandas as pd
+ df = self.spark.range(10)
+
+ @pandas_udf(returnType=TimestampType())
+ def gen_timestamps(id):
+ ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id]
+ return pd.Series(ts)
+
+ result = df.withColumn("ts", gen_timestamps(col("id"))).collect()
+ spark_ts_t = TimestampType()
+ for r in result:
+ i, ts = r
+ ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime()
+ expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
+ self.assertEquals(expected, ts)
+
+ def test_vectorized_udf_check_config(self):
+ from pyspark.sql.functions import pandas_udf, col
+ import pandas as pd
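+        # maxRecordsPerBatch caps the number of rows per Arrow batch, so each Series the
+        # UDF below receives should contain at most 3 records.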
+ with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
+ df = self.spark.range(10, numPartitions=1)
+
+ @pandas_udf(returnType=LongType())
+ def check_records_per_batch(x):
+ return pd.Series(x.size).repeat(x.size)
+
+ result = df.select(check_records_per_batch(col("id"))).collect()
+ for (r,) in result:
+ self.assertTrue(r <= 3)
+
+ def test_vectorized_udf_timestamps_respect_session_timezone(self):
+ from pyspark.sql.functions import pandas_udf, col
+ from datetime import datetime
+ import pandas as pd
+ schema = StructType([
+ StructField("idx", LongType(), True),
+ StructField("timestamp", TimestampType(), True)])
+ data = [(1, datetime(1969, 1, 1, 1, 1, 1)),
+ (2, datetime(2012, 2, 2, 2, 2, 2)),
+ (3, None),
+ (4, datetime(2100, 3, 3, 3, 3, 3))]
+ df = self.spark.createDataFrame(data, schema=schema)
+
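+        # internal_value exposes each timestamp's raw epoch value (pandas stores timestamps in
+        # nanoseconds), so results under the two timezone settings can be compared numerically.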
+ f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType())
+ internal_value = pandas_udf(
+ lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())
+
+ timezone = "America/New_York"
+ with self.sql_conf({
+ "spark.sql.execution.pandas.respectSessionTimeZone": False,
+ "spark.sql.session.timeZone": timezone}):
+ df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+ .withColumn("internal_value", internal_value(col("timestamp")))
+ result_la = df_la.select(col("idx"), col("internal_value")).collect()
+            # Correct result_la for the 3 hour difference between Los Angeles and New York
+            diff = 3 * 60 * 60 * 1000 * 1000 * 1000
+ result_la_corrected = \
+ df_la.select(col("idx"), col("tscopy"), col("internal_value") - diff).collect()
+
+ with self.sql_conf({
+ "spark.sql.execution.pandas.respectSessionTimeZone": True,
+ "spark.sql.session.timeZone": timezone}):
+ df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+ .withColumn("internal_value", internal_value(col("timestamp")))
+ result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()
+
+ self.assertNotEqual(result_ny, result_la)
+ self.assertEqual(result_ny, result_la_corrected)
+
+ def test_nondeterministic_vectorized_udf(self):
+ # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
+ from pyspark.sql.functions import pandas_udf, col
+
+ @pandas_udf('double')
+ def plus_ten(v):
+ return v + 10
+ random_udf = self.nondeterministic_vectorized_udf
+
+ df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
+ result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()
+
+ self.assertEqual(random_udf.deterministic, False)
+ self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))
+
+ def test_nondeterministic_vectorized_udf_in_aggregate(self):
+ from pyspark.sql.functions import sum
+
+ df = self.spark.range(10)
+ random_udf = self.nondeterministic_vectorized_udf
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
+ df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
+ with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
+ df.agg(sum(random_udf(df.id))).collect()
+
+ def test_register_vectorized_udf_basic(self):
+ from pyspark.rdd import PythonEvalType
+ from pyspark.sql.functions import pandas_udf, col, expr
+ df = self.spark.range(10).select(
+ col('id').cast('int').alias('a'),
+ col('id').cast('int').alias('b'))
+ original_add = pandas_udf(lambda x, y: x + y, IntegerType())
+ self.assertEqual(original_add.deterministic, True)
+ self.assertEqual(original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+ new_add = self.spark.catalog.registerFunction("add1", original_add)
+ res1 = df.select(new_add(col('a'), col('b')))
+ res2 = self.spark.sql(
+ "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t")
+ expected = df.select(expr('a + b'))
+ self.assertEquals(expected.collect(), res1.collect())
+ self.assertEquals(expected.collect(), res2.collect())
+
+ # Regression test for SPARK-23314
+ def test_timestamp_dst(self):
+ from pyspark.sql.functions import pandas_udf
+ # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
+ dt = [datetime.datetime(2015, 11, 1, 0, 30),
+ datetime.datetime(2015, 11, 1, 1, 30),
+ datetime.datetime(2015, 11, 1, 2, 30)]
+ df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
+ foo_udf = pandas_udf(lambda x: x, 'timestamp')
+ result = df.withColumn('time', foo_udf(df.time))
+ self.assertEquals(df.collect(), result.collect())
+
+ @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
+ def test_type_annotation(self):
+ from pyspark.sql.functions import pandas_udf
+ # Regression test to check if type hints can be used. See SPARK-23569.
+ # Note that it throws an error during compilation in lower Python versions if 'exec'
+ # is not used. Also, note that we explicitly use another dictionary to avoid modifications
+ # in the current 'locals()'.
+ #
+        # Hyukjin: I think it's an ugly way to test issues about syntax specific to
+ # higher versions of Python, which we shouldn't encourage. This was the last resort
+ # I could come up with at that time.
+ _locals = {}
+ exec(
+ "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col",
+ _locals)
+ df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
+ self.assertEqual(df.first()[0], 0)
+
+ def test_mixed_udf(self):
+ import pandas as pd
+ from pyspark.sql.functions import col, udf, pandas_udf
+
+ df = self.spark.range(0, 1).toDF('v')
+
+ # Test mixture of multiple UDFs and Pandas UDFs.
+
+ @udf('int')
+ def f1(x):
+ assert type(x) == int
+ return x + 1
+
+ @pandas_udf('int')
+ def f2(x):
+ assert type(x) == pd.Series
+ return x + 10
+
+ @udf('int')
+ def f3(x):
+ assert type(x) == int
+ return x + 100
+
+ @pandas_udf('int')
+ def f4(x):
+ assert type(x) == pd.Series
+ return x + 1000
+
+ # Test single expression with chained UDFs
+ df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
+ df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+ df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
+ df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
+ df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
+
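+        # f1 and f3 (regular UDFs) add 1 and 100; f2 and f4 (Pandas UDFs) add 10 and 1000, so
+        # each expected offset below encodes exactly which functions ran in the chain.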
+ expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
+ expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
+ expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
+ expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
+ expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)
+
+ self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
+ self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
+ self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
+ self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
+ self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())
+
+ # Test multiple mixed UDF expressions in a single projection
+ df_multi_1 = df \
+ .withColumn('f1', f1(col('v'))) \
+ .withColumn('f2', f2(col('v'))) \
+ .withColumn('f3', f3(col('v'))) \
+ .withColumn('f4', f4(col('v'))) \
+ .withColumn('f2_f1', f2(col('f1'))) \
+ .withColumn('f3_f1', f3(col('f1'))) \
+ .withColumn('f4_f1', f4(col('f1'))) \
+ .withColumn('f3_f2', f3(col('f2'))) \
+ .withColumn('f4_f2', f4(col('f2'))) \
+ .withColumn('f4_f3', f4(col('f3'))) \
+ .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
+ .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
+ .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
+ .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
+ .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))
+
+ # Test mixed udfs in a single expression
+ df_multi_2 = df \
+ .withColumn('f1', f1(col('v'))) \
+ .withColumn('f2', f2(col('v'))) \
+ .withColumn('f3', f3(col('v'))) \
+ .withColumn('f4', f4(col('v'))) \
+ .withColumn('f2_f1', f2(f1(col('v')))) \
+ .withColumn('f3_f1', f3(f1(col('v')))) \
+ .withColumn('f4_f1', f4(f1(col('v')))) \
+ .withColumn('f3_f2', f3(f2(col('v')))) \
+ .withColumn('f4_f2', f4(f2(col('v')))) \
+ .withColumn('f4_f3', f4(f3(col('v')))) \
+ .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
+ .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
+ .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
+ .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
+ .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
+
+ expected = df \
+ .withColumn('f1', df['v'] + 1) \
+ .withColumn('f2', df['v'] + 10) \
+ .withColumn('f3', df['v'] + 100) \
+ .withColumn('f4', df['v'] + 1000) \
+ .withColumn('f2_f1', df['v'] + 11) \
+ .withColumn('f3_f1', df['v'] + 101) \
+ .withColumn('f4_f1', df['v'] + 1001) \
+ .withColumn('f3_f2', df['v'] + 110) \
+ .withColumn('f4_f2', df['v'] + 1010) \
+ .withColumn('f4_f3', df['v'] + 1100) \
+ .withColumn('f3_f2_f1', df['v'] + 111) \
+ .withColumn('f4_f2_f1', df['v'] + 1011) \
+ .withColumn('f4_f3_f1', df['v'] + 1101) \
+ .withColumn('f4_f3_f2', df['v'] + 1110) \
+ .withColumn('f4_f3_f2_f1', df['v'] + 1111)
+
+ self.assertEquals(expected.collect(), df_multi_1.collect())
+ self.assertEquals(expected.collect(), df_multi_2.collect())
+
+ def test_mixed_udf_and_sql(self):
+ import pandas as pd
+ from pyspark.sql import Column
+ from pyspark.sql.functions import udf, pandas_udf
+
+ df = self.spark.range(0, 1).toDF('v')
+
+ # Test mixture of UDFs, Pandas UDFs and SQL expression.
+
+ @udf('int')
+ def f1(x):
+ assert type(x) == int
+ return x + 1
+
+ def f2(x):
+ assert type(x) == Column
+ return x + 10
+
+ @pandas_udf('int')
+ def f3(x):
+ assert type(x) == pd.Series
+ return x + 100
+
+ df1 = df.withColumn('f1', f1(df['v'])) \
+ .withColumn('f2', f2(df['v'])) \
+ .withColumn('f3', f3(df['v'])) \
+ .withColumn('f1_f2', f1(f2(df['v']))) \
+ .withColumn('f1_f3', f1(f3(df['v']))) \
+ .withColumn('f2_f1', f2(f1(df['v']))) \
+ .withColumn('f2_f3', f2(f3(df['v']))) \
+ .withColumn('f3_f1', f3(f1(df['v']))) \
+ .withColumn('f3_f2', f3(f2(df['v']))) \
+ .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
+ .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
+ .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
+ .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
+ .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
+ .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+
+ expected = df.withColumn('f1', df['v'] + 1) \
+ .withColumn('f2', df['v'] + 10) \
+ .withColumn('f3', df['v'] + 100) \
+ .withColumn('f1_f2', df['v'] + 11) \
+ .withColumn('f1_f3', df['v'] + 101) \
+ .withColumn('f2_f1', df['v'] + 11) \
+ .withColumn('f2_f3', df['v'] + 110) \
+ .withColumn('f3_f1', df['v'] + 101) \
+ .withColumn('f3_f2', df['v'] + 110) \
+ .withColumn('f1_f2_f3', df['v'] + 111) \
+ .withColumn('f1_f3_f2', df['v'] + 111) \
+ .withColumn('f2_f1_f3', df['v'] + 111) \
+ .withColumn('f2_f3_f1', df['v'] + 111) \
+ .withColumn('f3_f1_f2', df['v'] + 111) \
+ .withColumn('f3_f2_f1', df['v'] + 111)
+
+ self.assertEquals(expected.collect(), df1.collect())
+
+ # SPARK-24721
+ @unittest.skipIf(not test_compiled, test_not_compiled_message)
+ def test_datasource_with_udf(self):
+ # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF
+        # This needs to be a separate test because the Arrow dependency is optional
+ import pandas as pd
+ import numpy as np
+ from pyspark.sql.functions import pandas_udf, lit, col
+
+ path = tempfile.mkdtemp()
+ shutil.rmtree(path)
+
+ try:
+ self.spark.range(1).write.mode("overwrite").format('csv').save(path)
+ filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
+ datasource_df = self.spark.read \
+ .format("org.apache.spark.sql.sources.SimpleScanSource") \
+ .option('from', 0).option('to', 1).load().toDF('i')
+ datasource_v2_df = self.spark.read \
+ .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
+ .load().toDF('i', 'j')
+
+ c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1))
+ c2 = pandas_udf(lambda x: x + 1, 'int')(col('i'))
+
+ f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
+ f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i'))
+
+ for df in [filesource_df, datasource_df, datasource_v2_df]:
+ result = df.withColumn('c', c1)
+ expected = df.withColumn('c', lit(2))
+ self.assertEquals(expected.collect(), result.collect())
+
+ for df in [filesource_df, datasource_df, datasource_v2_df]:
+ result = df.withColumn('c', c2)
+ expected = df.withColumn('c', col('i') + 1)
+ self.assertEquals(expected.collect(), result.collect())
+
+ for df in [filesource_df, datasource_df, datasource_v2_df]:
+ for f in [f1, f2]:
+ result = df.filter(f)
+ self.assertEquals(0, result.count())
+ finally:
+ shutil.rmtree(path)
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_pandas_udf_scalar import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py
new file mode 100644
index 0000000000000..f0e6d2696df62
--- /dev/null
+++ b/python/pyspark/sql/tests/test_pandas_udf_window.py
@@ -0,0 +1,263 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.utils import AnalysisException
+from pyspark.sql.window import Window
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
+ pandas_requirement_message, pyarrow_requirement_message
+from pyspark.testing.utils import QuietTest
+
+
+@unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message)
+class WindowPandasUDFTests(ReusedSQLTestCase):
+ @property
+ def data(self):
+ from pyspark.sql.functions import array, explode, col, lit
+ return self.spark.range(10).toDF('id') \
+ .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
+ .withColumn("v", explode(col('vs'))) \
+ .drop('vs') \
+ .withColumn('w', lit(1.0))
+
+ @property
+ def python_plus_one(self):
+ from pyspark.sql.functions import udf
+ return udf(lambda v: v + 1, 'double')
+
+ @property
+ def pandas_scalar_time_two(self):
+ from pyspark.sql.functions import pandas_udf
+ return pandas_udf(lambda v: v * 2, 'double')
+
+ @property
+ def pandas_agg_mean_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def avg(v):
+ return v.mean()
+ return avg
+
+ @property
+ def pandas_agg_max_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def max(v):
+ return v.max()
+ return max
+
+ @property
+ def pandas_agg_min_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def min(v):
+ return v.min()
+ return min
+
+ @property
+ def unbounded_window(self):
+ return Window.partitionBy('id') \
+ .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+
+ @property
+ def ordered_window(self):
+ return Window.partitionBy('id').orderBy('v')
+
+ @property
+ def unpartitioned_window(self):
+ return Window.partitionBy()
+
+ def test_simple(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w = self.unbounded_window
+
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w))
+ expected1 = df.withColumn('mean_v', mean(df['v']).over(w))
+
+ result2 = df.select(mean_udf(df['v']).over(w))
+ expected2 = df.select(mean(df['v']).over(w))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+ def test_multiple_udfs(self):
+ from pyspark.sql.functions import max, min, mean
+
+ df = self.data
+ w = self.unbounded_window
+
+ result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
+ .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
+ .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))
+
+ expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
+ .withColumn('max_v', max(df['v']).over(w)) \
+ .withColumn('min_w', min(df['w']).over(w))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_replace_existing(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w = self.unbounded_window
+
+ result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w))
+ expected1 = df.withColumn('v', mean(df['v']).over(w))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_mixed_sql(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w = self.unbounded_window
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1)
+ expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1)
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_mixed_udf(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w = self.unbounded_window
+
+ plus_one = self.python_plus_one
+ time_two = self.pandas_scalar_time_two
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn(
+ 'v2',
+ plus_one(mean_udf(plus_one(df['v'])).over(w)))
+ expected1 = df.withColumn(
+ 'v2',
+ plus_one(mean(plus_one(df['v'])).over(w)))
+
+ result2 = df.withColumn(
+ 'v2',
+ time_two(mean_udf(time_two(df['v'])).over(w)))
+ expected2 = df.withColumn(
+ 'v2',
+ time_two(mean(time_two(df['v'])).over(w)))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+ def test_without_partitionBy(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w = self.unpartitioned_window
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn('v2', mean_udf(df['v']).over(w))
+ expected1 = df.withColumn('v2', mean(df['v']).over(w))
+
+ result2 = df.select(mean_udf(df['v']).over(w))
+ expected2 = df.select(mean(df['v']).over(w))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+ def test_mixed_sql_and_udf(self):
+ from pyspark.sql.functions import max, min, rank, col
+
+ df = self.data
+ w = self.unbounded_window
+ ow = self.ordered_window
+ max_udf = self.pandas_agg_max_udf
+ min_udf = self.pandas_agg_min_udf
+
+ result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w))
+ expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w))
+
+ # Test mixing sql window function and window udf in the same expression
+ result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w))
+ expected2 = expected1
+
+ # Test chaining sql aggregate function and udf
+ result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
+ .withColumn('min_v', min(df['v']).over(w)) \
+ .withColumn('v_diff', col('max_v') - col('min_v')) \
+ .drop('max_v', 'min_v')
+ expected3 = expected1
+
+ # Test mixing sql window function and udf
+ result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
+ .withColumn('rank', rank().over(ow))
+ expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
+ .withColumn('rank', rank().over(ow))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+ self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+ self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+
+ def test_array_type(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ df = self.data
+ w = self.unbounded_window
+
+        array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
+ result1 = df.withColumn('v2', array_udf(df['v']).over(w))
+ self.assertEquals(result1.first()['v2'], [1.0, 2.0])
+
+ def test_invalid_args(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ df = self.data
+ w = self.unbounded_window
+ ow = self.ordered_window
+ mean_udf = self.pandas_agg_mean_udf
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ AnalysisException,
+ '.*not supported within a window function'):
+ foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
+ df.withColumn('v2', foo_udf(df['v']).over(w))
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ AnalysisException,
+ '.*Only unbounded window frame is supported.*'):
+ df.withColumn('mean_v', mean_udf(df['v']).over(ow))
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_pandas_udf_window import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
new file mode 100644
index 0000000000000..2f8712d7631f5
--- /dev/null
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -0,0 +1,154 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import tempfile
+
+from pyspark.sql.types import *
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class ReadwriterTests(ReusedSQLTestCase):
+
+ def test_save_and_load(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.write.json(tmpPath)
+ actual = self.spark.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.spark.read.json(tmpPath, schema)
+ self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+ df.write.json(tmpPath, "overwrite")
+ actual = self.spark.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ df.write.save(format="json", mode="overwrite", path=tmpPath,
+ noUse="this options will not be used in save.")
+ actual = self.spark.read.load(format="json", path=tmpPath,
+ noUse="this options will not be used in load.")
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ actual = self.spark.read.load(path=tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ csvpath = os.path.join(tempfile.mkdtemp(), 'data')
+ df.write.option('quote', None).format('csv').save(csvpath)
+
+ shutil.rmtree(tmpPath)
+
+ def test_save_and_load_builder(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.write.json(tmpPath)
+ actual = self.spark.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.spark.read.json(tmpPath, schema)
+ self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+ df.write.mode("overwrite").json(tmpPath)
+ actual = self.spark.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
+ .option("noUse", "this option will not be used in save.")\
+ .format("json").save(path=tmpPath)
+ actual =\
+ self.spark.read.format("json")\
+ .load(path=tmpPath, noUse="this options will not be used in load.")
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ actual = self.spark.read.load(path=tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
+
+ def test_bucketed_write(self):
+ data = [
+ (1, "foo", 3.0), (2, "foo", 5.0),
+ (3, "bar", -1.0), (4, "bar", 6.0),
+ ]
+ df = self.spark.createDataFrame(data, ["x", "y", "z"])
+
+ def count_bucketed_cols(names, table="pyspark_bucket"):
+ """Given a sequence of column names and a table name
+ query the catalog and return number o columns which are
+ used for bucketing
+ """
+ cols = self.spark.catalog.listColumns(table)
+ num = len([c for c in cols if c.name in names and c.isBucket])
+ return num
+
+ with self.table("pyspark_bucket"):
+ # Test write with one bucketing column
+ df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
+ self.assertEqual(count_bucketed_cols(["x"]), 1)
+ self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+ # Test write two bucketing columns
+ df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
+ self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
+ self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+ # Test write with bucket and sort
+ df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
+ self.assertEqual(count_bucketed_cols(["x"]), 1)
+ self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+ # Test write with a list of columns
+ df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
+ self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
+ self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+ # Test write with bucket and sort with a list of columns
+ (df.write.bucketBy(2, "x")
+ .sortBy(["y", "z"])
+ .mode("overwrite").saveAsTable("pyspark_bucket"))
+ self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+ # Test write with bucket and sort with multiple columns
+ (df.write.bucketBy(2, "x")
+ .sortBy("y", "z")
+ .mode("overwrite").saveAsTable("pyspark_bucket"))
+ self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.test_readwriter import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py
new file mode 100644
index 0000000000000..8707f46b6a25a
--- /dev/null
+++ b/python/pyspark/sql/tests/test_serde.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import shutil
+import tempfile
+import time
+
+from pyspark.sql import Row
+from pyspark.sql.functions import lit
+from pyspark.sql.types import *
+from pyspark.testing.sqlutils import ReusedSQLTestCase, UTCOffsetTimezone
+
+
+class SerdeTests(ReusedSQLTestCase):
+
+ def test_serialize_nested_array_and_map(self):
+ d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
+ rdd = self.sc.parallelize(d)
+ df = self.spark.createDataFrame(rdd)
+ row = df.head()
+ self.assertEqual(1, len(row.l))
+ self.assertEqual(1, row.l[0].a)
+ self.assertEqual("2", row.d["key"].d)
+
+ l = df.rdd.map(lambda x: x.l).first()
+ self.assertEqual(1, len(l))
+ self.assertEqual('s', l[0].b)
+
+ d = df.rdd.map(lambda x: x.d).first()
+ self.assertEqual(1, len(d))
+ self.assertEqual(1.0, d["key"].c)
+
+ row = df.rdd.map(lambda x: x.d["key"]).first()
+ self.assertEqual(1.0, row.c)
+ self.assertEqual("2", row.d)
+
+ def test_select_null_literal(self):
+ df = self.spark.sql("select null as col")
+ self.assertEqual(Row(col=None), df.first())
+
+ def test_struct_in_map(self):
+ d = [Row(m={Row(i=1): Row(s="")})]
+ df = self.sc.parallelize(d).toDF()
+ k, v = list(df.head().m.items())[0]
+ self.assertEqual(1, k.i)
+ self.assertEqual("", v.s)
+
+ def test_filter_with_datetime(self):
+ time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
+ date = time.date()
+ row = Row(date=date, time=time)
+ df = self.spark.createDataFrame([row])
+ self.assertEqual(1, df.filter(df.date == date).count())
+ self.assertEqual(1, df.filter(df.time == time).count())
+ self.assertEqual(0, df.filter(df.date > date).count())
+ self.assertEqual(0, df.filter(df.time > time).count())
+
+ def test_filter_with_datetime_timezone(self):
+ dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
+ dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
+ row = Row(date=dt1)
+ df = self.spark.createDataFrame([row])
+ self.assertEqual(0, df.filter(df.date == dt2).count())
+ self.assertEqual(1, df.filter(df.date > dt2).count())
+ self.assertEqual(0, df.filter(df.date < dt2).count())
+
+ def test_time_with_timezone(self):
+ day = datetime.date.today()
+ now = datetime.datetime.now()
+ ts = time.mktime(now.timetuple())
+ # class in __main__ is not serializable
+ from pyspark.testing.sqlutils import UTCOffsetTimezone
+ utc = UTCOffsetTimezone()
+ utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds
+ # add microseconds to utcnow (keeping year,month,day,hour,minute,second)
+ utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
+ df = self.spark.createDataFrame([(day, now, utcnow)])
+ day1, now1, utcnow1 = df.first()
+ self.assertEqual(day1, day)
+ self.assertEqual(now, now1)
+ self.assertEqual(now, utcnow1)
+
+ # regression test for SPARK-19561
+ def test_datetime_at_epoch(self):
+ epoch = datetime.datetime.fromtimestamp(0)
+ df = self.spark.createDataFrame([Row(date=epoch)])
+ first = df.select('date', lit(epoch).alias('lit_date')).first()
+ self.assertEqual(first['date'], epoch)
+ self.assertEqual(first['lit_date'], epoch)
+
+ def test_decimal(self):
+ from decimal import Decimal
+ schema = StructType([StructField("decimal", DecimalType(10, 5))])
+ df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema)
+ row = df.select(df.decimal + 1).first()
+ self.assertEqual(row[0], Decimal("4.14159"))
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.write.parquet(tmpPath)
+ df2 = self.spark.read.parquet(tmpPath)
+ row = df2.first()
+ self.assertEqual(row[0], Decimal("3.14159"))
+
+ def test_BinaryType_serialization(self):
+        # Pyrolite version <= 4.9 could not serialize BinaryType with Python3 (SPARK-17808).
+        # The empty bytearray is a test for SPARK-21534.
+ schema = StructType([StructField('mybytes', BinaryType())])
+ data = [[bytearray(b'here is my data')],
+ [bytearray(b'and here is some more')],
+ [bytearray(b'')]]
+ df = self.spark.createDataFrame(data, schema=schema)
+ df.collect()
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.test_serde import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
new file mode 100644
index 0000000000000..c6b9e0b2ca554
--- /dev/null
+++ b/python/pyspark/sql/tests/test_session.py
@@ -0,0 +1,321 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+
+from pyspark import SparkConf, SparkContext
+from pyspark.sql import SparkSession, SQLContext, Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import PySparkTestCase
+
+
+class SparkSessionTests(ReusedSQLTestCase):
+ def test_sqlcontext_reuses_sparksession(self):
+ sqlContext1 = SQLContext(self.sc)
+ sqlContext2 = SQLContext(self.sc)
+ self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)
+
+
+class SparkSessionTests1(ReusedSQLTestCase):
+
+    # We can't include this test in SQLTests because stopping the class's SparkContext would
+    # cause other tests to fail.
+ def test_sparksession_with_stopped_sparkcontext(self):
+ self.sc.stop()
+ sc = SparkContext('local[4]', self.sc.appName)
+ spark = SparkSession.builder.getOrCreate()
+ try:
+ df = spark.createDataFrame([(1, 2)], ["c", "c"])
+ df.collect()
+ finally:
+ spark.stop()
+ sc.stop()
+
+
+class SparkSessionTests2(PySparkTestCase):
+
+    # This test is separate because it's closely related to the session's start and stop.
+ # See SPARK-23228.
+ def test_set_jvm_default_session(self):
+ spark = SparkSession.builder.getOrCreate()
+ try:
+ self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
+ finally:
+ spark.stop()
+ self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty())
+
+ def test_jvm_default_session_already_set(self):
+ # Here, we assume there is the default session already set in JVM.
+ jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc())
+ self.sc._jvm.SparkSession.setDefaultSession(jsession)
+
+ spark = SparkSession.builder.getOrCreate()
+ try:
+ self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
+            # The session should be the same as the existing one.
+ self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get()))
+ finally:
+ spark.stop()
+
+
+class SparkSessionTests3(unittest.TestCase):
+
+ def test_active_session(self):
+ spark = SparkSession.builder \
+ .master("local") \
+ .getOrCreate()
+ try:
+ activeSession = SparkSession.getActiveSession()
+ df = activeSession.createDataFrame([(1, 'Alice')], ['age', 'name'])
+ self.assertEqual(df.collect(), [Row(age=1, name=u'Alice')])
+ finally:
+ spark.stop()
+
+ def test_get_active_session_when_no_active_session(self):
+ active = SparkSession.getActiveSession()
+ self.assertEqual(active, None)
+ spark = SparkSession.builder \
+ .master("local") \
+ .getOrCreate()
+ active = SparkSession.getActiveSession()
+ self.assertEqual(active, spark)
+ spark.stop()
+ active = SparkSession.getActiveSession()
+ self.assertEqual(active, None)
+
+ def test_SparkSession(self):
+ spark = SparkSession.builder \
+ .master("local") \
+ .config("some-config", "v2") \
+ .getOrCreate()
+ try:
+ self.assertEqual(spark.conf.get("some-config"), "v2")
+ self.assertEqual(spark.sparkContext._conf.get("some-config"), "v2")
+ self.assertEqual(spark.version, spark.sparkContext.version)
+ spark.sql("CREATE DATABASE test_db")
+ spark.catalog.setCurrentDatabase("test_db")
+ self.assertEqual(spark.catalog.currentDatabase(), "test_db")
+ spark.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet")
+ self.assertEqual(spark.table("table1").columns, ['name', 'age'])
+ self.assertEqual(spark.range(3).count(), 3)
+ finally:
+ spark.stop()
+
+ def test_global_default_session(self):
+ spark = SparkSession.builder \
+ .master("local") \
+ .getOrCreate()
+ try:
+ self.assertEqual(SparkSession.builder.getOrCreate(), spark)
+ finally:
+ spark.stop()
+
+ def test_default_and_active_session(self):
+ spark = SparkSession.builder \
+ .master("local") \
+ .getOrCreate()
+ activeSession = spark._jvm.SparkSession.getActiveSession()
+ defaultSession = spark._jvm.SparkSession.getDefaultSession()
+ try:
+ self.assertEqual(activeSession, defaultSession)
+ finally:
+ spark.stop()
+
+ def test_config_option_propagated_to_existing_session(self):
+ session1 = SparkSession.builder \
+ .master("local") \
+ .config("spark-config1", "a") \
+ .getOrCreate()
+ self.assertEqual(session1.conf.get("spark-config1"), "a")
+ session2 = SparkSession.builder \
+ .config("spark-config1", "b") \
+ .getOrCreate()
+ try:
+ self.assertEqual(session1, session2)
+ self.assertEqual(session1.conf.get("spark-config1"), "b")
+ finally:
+ session1.stop()
+
+ def test_new_session(self):
+ session = SparkSession.builder \
+ .master("local") \
+ .getOrCreate()
+ newSession = session.newSession()
+ try:
+ self.assertNotEqual(session, newSession)
+ finally:
+ session.stop()
+ newSession.stop()
+
+ def test_create_new_session_if_old_session_stopped(self):
+ session = SparkSession.builder \
+ .master("local") \
+ .getOrCreate()
+ session.stop()
+ newSession = SparkSession.builder \
+ .master("local") \
+ .getOrCreate()
+ try:
+ self.assertNotEqual(session, newSession)
+ finally:
+ newSession.stop()
+
+ def test_active_session_with_None_and_not_None_context(self):
+ from pyspark.context import SparkContext
+ from pyspark.conf import SparkConf
+ sc = None
+ session = None
+ try:
+ sc = SparkContext._active_spark_context
+ self.assertEqual(sc, None)
+ activeSession = SparkSession.getActiveSession()
+ self.assertEqual(activeSession, None)
+ sparkConf = SparkConf()
+ sc = SparkContext.getOrCreate(sparkConf)
+ activeSession = sc._jvm.SparkSession.getActiveSession()
+ self.assertFalse(activeSession.isDefined())
+ session = SparkSession(sc)
+ activeSession = sc._jvm.SparkSession.getActiveSession()
+ self.assertTrue(activeSession.isDefined())
+ activeSession2 = SparkSession.getActiveSession()
+ self.assertNotEqual(activeSession2, None)
+ finally:
+ if session is not None:
+ session.stop()
+ if sc is not None:
+ sc.stop()
+
+
+class SparkSessionTests4(ReusedSQLTestCase):
+
+ def test_get_active_session_after_create_dataframe(self):
+ session2 = None
+ try:
+ activeSession1 = SparkSession.getActiveSession()
+ session1 = self.spark
+ self.assertEqual(session1, activeSession1)
+ session2 = self.spark.newSession()
+ activeSession2 = SparkSession.getActiveSession()
+ self.assertEqual(session1, activeSession2)
+ self.assertNotEqual(session2, activeSession2)
+ session2.createDataFrame([(1, 'Alice')], ['age', 'name'])
+ activeSession3 = SparkSession.getActiveSession()
+ self.assertEqual(session2, activeSession3)
+ session1.createDataFrame([(1, 'Alice')], ['age', 'name'])
+ activeSession4 = SparkSession.getActiveSession()
+ self.assertEqual(session1, activeSession4)
+ finally:
+ if session2 is not None:
+ session2.stop()
+
+
+class SparkSessionBuilderTests(unittest.TestCase):
+
+ def test_create_spark_context_first_then_spark_session(self):
+ sc = None
+ session = None
+ try:
+ conf = SparkConf().set("key1", "value1")
+ sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf)
+ session = SparkSession.builder.config("key2", "value2").getOrCreate()
+
+ self.assertEqual(session.conf.get("key1"), "value1")
+ self.assertEqual(session.conf.get("key2"), "value2")
+ self.assertEqual(session.sparkContext, sc)
+
+ self.assertFalse(sc.getConf().contains("key2"))
+ self.assertEqual(sc.getConf().get("key1"), "value1")
+ finally:
+ if session is not None:
+ session.stop()
+ if sc is not None:
+ sc.stop()
+
+ def test_another_spark_session(self):
+ session1 = None
+ session2 = None
+ try:
+ session1 = SparkSession.builder.config("key1", "value1").getOrCreate()
+ session2 = SparkSession.builder.config("key2", "value2").getOrCreate()
+
+ self.assertEqual(session1.conf.get("key1"), "value1")
+ self.assertEqual(session2.conf.get("key1"), "value1")
+ self.assertEqual(session1.conf.get("key2"), "value2")
+ self.assertEqual(session2.conf.get("key2"), "value2")
+ self.assertEqual(session1.sparkContext, session2.sparkContext)
+
+ self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1")
+ self.assertFalse(session1.sparkContext.getConf().contains("key2"))
+ finally:
+ if session1 is not None:
+ session1.stop()
+ if session2 is not None:
+ session2.stop()
+
+
+class SparkExtensionsTest(unittest.TestCase):
+    # These tests are separate because they use 'spark.sql.extensions', which is
+    # static and immutable. It can't be set or unset, for example, via `spark.conf`.
+
+ @classmethod
+ def setUpClass(cls):
+ import glob
+ from pyspark.find_spark_home import _find_spark_home
+
+ SPARK_HOME = _find_spark_home()
+ filename_pattern = (
+ "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+ "SparkSessionExtensionSuite.class")
+ if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+ raise unittest.SkipTest(
+ "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
+ "available. Will skip the related tests.")
+
+ # Note that 'spark.sql.extensions' is a static immutable configuration.
+ cls.spark = SparkSession.builder \
+ .master("local[4]") \
+ .appName(cls.__name__) \
+ .config(
+ "spark.sql.extensions",
+ "org.apache.spark.sql.MyExtensions") \
+ .getOrCreate()
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.spark.stop()
+
+ def test_use_custom_class_for_extensions(self):
+ self.assertTrue(
+ self.spark._jsparkSession.sessionState().planner().strategies().contains(
+ self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)),
+ "MySparkStrategy not found in active planner strategies")
+ self.assertTrue(
+ self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains(
+ self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)),
+ "MyRule not found in extended resolution rules")
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_session import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py
new file mode 100644
index 0000000000000..4b71759f74a55
--- /dev/null
+++ b/python/pyspark/sql/tests/test_streaming.py
@@ -0,0 +1,567 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import tempfile
+import time
+
+from pyspark.sql.functions import lit
+from pyspark.sql.types import *
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class StreamingTests(ReusedSQLTestCase):
+
+ def test_stream_trigger(self):
+ df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+
+ # Should take at least one arg
+ try:
+ df.writeStream.trigger()
+ except ValueError:
+ pass
+
+ # Should not take multiple args
+ try:
+ df.writeStream.trigger(once=True, processingTime='5 seconds')
+ except ValueError:
+ pass
+
+ # Should not take multiple args
+ try:
+ df.writeStream.trigger(processingTime='5 seconds', continuous='1 second')
+ except ValueError:
+ pass
+
+ # Should take only keyword args
+ try:
+ df.writeStream.trigger('5 seconds')
+ self.fail("Should have thrown an exception")
+ except TypeError:
+ pass
+
+ def test_stream_read_options(self):
+ schema = StructType([StructField("data", StringType(), False)])
+ df = self.spark.readStream\
+ .format('text')\
+ .option('path', 'python/test_support/sql/streaming')\
+ .schema(schema)\
+ .load()
+ self.assertTrue(df.isStreaming)
+        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
+
+ def test_stream_read_options_overwrite(self):
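+        # Arguments passed to load() should override options set earlier on the reader.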
+ bad_schema = StructType([StructField("test", IntegerType(), False)])
+ schema = StructType([StructField("data", StringType(), False)])
+ df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \
+ .schema(bad_schema)\
+ .load(path='python/test_support/sql/streaming', schema=schema, format='text')
+ self.assertTrue(df.isStreaming)
+        self.assertEqual(df.schema.simpleString(), "struct<data:string>")
+
+ def test_stream_save_options(self):
+ df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \
+ .withColumn('id', lit(1))
+ for q in self.spark._wrapped.streams.active:
+ q.stop()
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ self.assertTrue(df.isStreaming)
+ out = os.path.join(tmpPath, 'out')
+ chk = os.path.join(tmpPath, 'chk')
+ q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \
+ .format('parquet').partitionBy('id').outputMode('append').option('path', out).start()
+ try:
+ self.assertEqual(q.name, 'this_query')
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ output_files = []
+ for _, _, files in os.walk(out):
+ output_files.extend([f for f in files if not f.startswith('.')])
+ self.assertTrue(len(output_files) > 0)
+ self.assertTrue(len(os.listdir(chk)) > 0)
+ finally:
+ q.stop()
+ shutil.rmtree(tmpPath)
+
+ def test_stream_save_options_overwrite(self):
+ df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+ for q in self.spark._wrapped.streams.active:
+ q.stop()
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ self.assertTrue(df.isStreaming)
+ out = os.path.join(tmpPath, 'out')
+ chk = os.path.join(tmpPath, 'chk')
+ fake1 = os.path.join(tmpPath, 'fake1')
+ fake2 = os.path.join(tmpPath, 'fake2')
+ q = df.writeStream.option('checkpointLocation', fake1)\
+ .format('memory').option('path', fake2) \
+ .queryName('fake_query').outputMode('append') \
+ .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
+
+ try:
+ self.assertEqual(q.name, 'this_query')
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ output_files = []
+ for _, _, files in os.walk(out):
+ output_files.extend([f for f in files if not f.startswith('.')])
+ self.assertTrue(len(output_files) > 0)
+ self.assertTrue(len(os.listdir(chk)) > 0)
+ self.assertFalse(os.path.isdir(fake1)) # should not have been created
+ self.assertFalse(os.path.isdir(fake2)) # should not have been created
+ finally:
+ q.stop()
+ shutil.rmtree(tmpPath)
+
+ def test_stream_status_and_progress(self):
+ df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+ for q in self.spark._wrapped.streams.active:
+ q.stop()
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ self.assertTrue(df.isStreaming)
+ out = os.path.join(tmpPath, 'out')
+ chk = os.path.join(tmpPath, 'chk')
+
+ def func(x):
+ time.sleep(1)
+ return x
+
+ from pyspark.sql.functions import col, udf
+ sleep_udf = udf(func)
+
+ # Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there
+ # were no updates.
+ q = df.select(sleep_udf(col("value")).alias('value')).writeStream \
+ .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
+ try:
+ # "lastProgress" will return None in most cases. However, as it may be flaky when
+ # Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress"
+ # may throw error with a high chance and make this test flaky, so we should still be
+ # able to detect broken codes.
+ q.lastProgress
+
+ q.processAllAvailable()
+ lastProgress = q.lastProgress
+ recentProgress = q.recentProgress
+ status = q.status
+ self.assertEqual(lastProgress['name'], q.name)
+ self.assertEqual(lastProgress['id'], q.id)
+ self.assertTrue(any(p == lastProgress for p in recentProgress))
+ self.assertTrue(
+ "message" in status and
+ "isDataAvailable" in status and
+ "isTriggerActive" in status)
+ finally:
+ q.stop()
+ shutil.rmtree(tmpPath)
+
+ def test_stream_await_termination(self):
+ df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+ for q in self.spark._wrapped.streams.active:
+ q.stop()
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ self.assertTrue(df.isStreaming)
+ out = os.path.join(tmpPath, 'out')
+ chk = os.path.join(tmpPath, 'chk')
+ q = df.writeStream\
+ .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
+ try:
+ self.assertTrue(q.isActive)
+ try:
+ q.awaitTermination("hello")
+ self.fail("Expected a value exception")
+ except ValueError:
+ pass
+ now = time.time()
+ # test should take at least 2 seconds
+ res = q.awaitTermination(2.6)
+ duration = time.time() - now
+ self.assertTrue(duration >= 2)
+ self.assertFalse(res)
+ finally:
+ q.stop()
+ shutil.rmtree(tmpPath)
+
+ def test_stream_exception(self):
+ sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+ sq = sdf.writeStream.format('memory').queryName('query_explain').start()
+ try:
+ sq.processAllAvailable()
+ self.assertEqual(sq.exception(), None)
+ finally:
+ sq.stop()
+
+ from pyspark.sql.functions import col, udf
+ from pyspark.sql.utils import StreamingQueryException
+ bad_udf = udf(lambda x: 1 / 0)
+ sq = sdf.select(bad_udf(col("value")))\
+ .writeStream\
+ .format('memory')\
+ .queryName('this_query')\
+ .start()
+ try:
+ # Process some data to fail the query
+ sq.processAllAvailable()
+ self.fail("bad udf should fail the query")
+ except StreamingQueryException as e:
+ # This is expected
+ self.assertTrue("ZeroDivisionError" in e.desc)
+ finally:
+ sq.stop()
+ self.assertTrue(type(sq.exception()) is StreamingQueryException)
+ self.assertTrue("ZeroDivisionError" in sq.exception().desc)
+
+ def test_query_manager_await_termination(self):
+ df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+ for q in self.spark._wrapped.streams.active:
+ q.stop()
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ self.assertTrue(df.isStreaming)
+ out = os.path.join(tmpPath, 'out')
+ chk = os.path.join(tmpPath, 'chk')
+ q = df.writeStream\
+ .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
+ try:
+ self.assertTrue(q.isActive)
+ try:
+ self.spark._wrapped.streams.awaitAnyTermination("hello")
+ self.fail("Expected a value exception")
+ except ValueError:
+ pass
+ now = time.time()
+ # test should take at least 2 seconds
+ res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
+ duration = time.time() - now
+ self.assertTrue(duration >= 2)
+ self.assertFalse(res)
+ finally:
+ q.stop()
+ shutil.rmtree(tmpPath)
+
+ class ForeachWriterTester:
+
+ def __init__(self, spark):
+ self.spark = spark
+
+ def write_open_event(self, partitionId, epochId):
+ self._write_event(
+ self.open_events_dir,
+ {'partition': partitionId, 'epoch': epochId})
+
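+        # Every input file contains the single value 'text' (see write_input_file), so
+        # recording a constant payload is enough to count processed rows.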
+ def write_process_event(self, row):
+ self._write_event(self.process_events_dir, {'value': 'text'})
+
+ def write_close_event(self, error):
+ self._write_event(self.close_events_dir, {'error': str(error)})
+
+ def write_input_file(self):
+ self._write_event(self.input_dir, "text")
+
+ def open_events(self):
+ return self._read_events(self.open_events_dir, 'partition INT, epoch INT')
+
+ def process_events(self):
+ return self._read_events(self.process_events_dir, 'value STRING')
+
+ def close_events(self):
+ return self._read_events(self.close_events_dir, 'error STRING')
+
+ def run_streaming_query_on_writer(self, writer, num_files):
+ self._reset()
+ try:
+ sdf = self.spark.readStream.format('text').load(self.input_dir)
+ sq = sdf.writeStream.foreach(writer).start()
+ for i in range(num_files):
+ self.write_input_file()
+ sq.processAllAvailable()
+ finally:
+ self.stop_all()
+
+ def assert_invalid_writer(self, writer, msg=None):
+ self._reset()
+ try:
+ sdf = self.spark.readStream.format('text').load(self.input_dir)
+ sq = sdf.writeStream.foreach(writer).start()
+ self.write_input_file()
+ sq.processAllAvailable()
+ self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected
+ except Exception as e:
+ if msg:
+ assert msg in str(e), "%s not in %s" % (msg, str(e))
+
+ finally:
+ self.stop_all()
+
+ def stop_all(self):
+ for q in self.spark._wrapped.streams.active:
+ q.stop()
+
+ def _reset(self):
+ self.input_dir = tempfile.mkdtemp()
+ self.open_events_dir = tempfile.mkdtemp()
+ self.process_events_dir = tempfile.mkdtemp()
+ self.close_events_dir = tempfile.mkdtemp()
+
+ def _read_events(self, dir, json):
+ rows = self.spark.read.schema(json).json(dir).collect()
+ dicts = [row.asDict() for row in rows]
+ return dicts
+
+ def _write_event(self, dir, event):
+ import uuid
+ with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f:
+ f.write("%s\n" % str(event))
+
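+        # The tester is pickled along with the foreach writer and shipped to executors, so
+        # keep only the event directories; the SparkSession itself cannot be pickled.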
+ def __getstate__(self):
+ return (self.open_events_dir, self.process_events_dir, self.close_events_dir)
+
+ def __setstate__(self, state):
+ self.open_events_dir, self.process_events_dir, self.close_events_dir = state
+
+    # These foreach tests fail on Python 3.6 with macOS High Sierra because of the rules at
+    # http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html
+    # To work around this, set OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES.
+ def test_streaming_foreach_with_simple_function(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ def foreach_func(row):
+ tester.write_process_event(row)
+
+ tester.run_streaming_query_on_writer(foreach_func, 2)
+ self.assertEqual(len(tester.process_events()), 2)
+
+ def test_streaming_foreach_with_basic_open_process_close(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partitionId, epochId):
+ tester.write_open_event(partitionId, epochId)
+ return True
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+ open_events = tester.open_events()
+ self.assertEqual(len(open_events), 2)
+ self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})
+
+ self.assertEqual(len(tester.process_events()), 2)
+
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 2)
+ self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+ def test_streaming_foreach_with_open_returning_false(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partition_id, epoch_id):
+ tester.write_open_event(partition_id, epoch_id)
+ return False
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+ self.assertEqual(len(tester.open_events()), 2)
+
+ self.assertEqual(len(tester.process_events()), 0) # no row was processed
+
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 2)
+ self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+ def test_streaming_foreach_without_open_method(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+ self.assertEqual(len(tester.open_events()), 0) # no open events
+ self.assertEqual(len(tester.process_events()), 2)
+ self.assertEqual(len(tester.close_events()), 2)
+
+ def test_streaming_foreach_without_close_method(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partition_id, epoch_id):
+ tester.write_open_event(partition_id, epoch_id)
+ return True
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+        self.assertEqual(len(tester.open_events()), 2)  # open events are still written
+ self.assertEqual(len(tester.process_events()), 2)
+ self.assertEqual(len(tester.close_events()), 0)
+
+ def test_streaming_foreach_without_open_and_close_methods(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def process(self, row):
+ tester.write_process_event(row)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+ self.assertEqual(len(tester.open_events()), 0) # no open events
+ self.assertEqual(len(tester.process_events()), 2)
+ self.assertEqual(len(tester.close_events()), 0)
+
+ def test_streaming_foreach_with_process_throwing_error(self):
+ from pyspark.sql.utils import StreamingQueryException
+
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def process(self, row):
+ raise Exception("test error")
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ try:
+ tester.run_streaming_query_on_writer(ForeachWriter(), 1)
+ self.fail("bad writer did not fail the query") # this is not expected
+ except StreamingQueryException as e:
+ # TODO: Verify whether original error message is inside the exception
+ pass
+
+ self.assertEqual(len(tester.process_events()), 0) # no row was processed
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 1)
+ # TODO: Verify whether original error message is inside the exception
+
+ def test_streaming_foreach_with_invalid_writers(self):
+
+ tester = self.ForeachWriterTester(self.spark)
+
+ def func_with_iterator_input(iter):
+ for x in iter:
+ print(x)
+
+ tester.assert_invalid_writer(func_with_iterator_input)
+
+ class WriterWithoutProcess:
+ def open(self, partition):
+ pass
+
+ tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'")
+
+ class WriterWithNonCallableProcess():
+ process = True
+
+ tester.assert_invalid_writer(WriterWithNonCallableProcess(),
+ "'process' in provided object is not callable")
+
+ class WriterWithNoParamProcess():
+ def process(self):
+ pass
+
+ tester.assert_invalid_writer(WriterWithNoParamProcess())
+
+ # Abstract class for tests below
+ class WithProcess():
+ def process(self, row):
+ pass
+
+ class WriterWithNonCallableOpen(WithProcess):
+ open = True
+
+ tester.assert_invalid_writer(WriterWithNonCallableOpen(),
+ "'open' in provided object is not callable")
+
+ class WriterWithNoParamOpen(WithProcess):
+ def open(self):
+ pass
+
+ tester.assert_invalid_writer(WriterWithNoParamOpen())
+
+ class WriterWithNonCallableClose(WithProcess):
+ close = True
+
+ tester.assert_invalid_writer(WriterWithNonCallableClose(),
+ "'close' in provided object is not callable")
+
+ def test_streaming_foreachBatch(self):
+ q = None
+ collected = dict()
+
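+        # foreachBatch invokes this function on the driver for every micro-batch, so the
+        # results can be collected into a local dict keyed by batch id.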
+ def collectBatch(batch_df, batch_id):
+ collected[batch_id] = batch_df.collect()
+
+ try:
+ df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+ q = df.writeStream.foreachBatch(collectBatch).start()
+ q.processAllAvailable()
+ self.assertTrue(0 in collected)
+            self.assertEqual(len(collected[0]), 2)
+ finally:
+ if q:
+ q.stop()
+
+ def test_streaming_foreachBatch_propagates_python_errors(self):
+ from pyspark.sql.utils import StreamingQueryException
+
+ q = None
+
+ def collectBatch(df, id):
+ raise Exception("this should fail the query")
+
+ try:
+ df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+ q = df.writeStream.foreachBatch(collectBatch).start()
+ q.processAllAvailable()
+ self.fail("Expected a failure")
+ except StreamingQueryException as e:
+ self.assertTrue("this should fail" in str(e))
+ finally:
+ if q:
+ q.stop()
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.test_streaming import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
new file mode 100644
index 0000000000000..fb673f2a385ef
--- /dev/null
+++ b/python/pyspark/sql/tests/test_types.py
@@ -0,0 +1,945 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import array
+import ctypes
+import datetime
+import os
+import pickle
+import sys
+import unittest
+
+from pyspark.sql import Row
+from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.types import *
+from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings, \
+ _array_unsigned_int_typecode_ctype_mappings, _infer_type, _make_type_verifier, _merge_type
+from pyspark.testing.sqlutils import ReusedSQLTestCase, ExamplePointUDT, PythonOnlyUDT, \
+ ExamplePoint, PythonOnlyPoint, MyObject
+
+
+class TypesTests(ReusedSQLTestCase):
+
+ def test_apply_schema_to_row(self):
+ df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
+ df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
+ self.assertEqual(df.collect(), df2.collect())
+
+ rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
+ df3 = self.spark.createDataFrame(rdd, df.schema)
+ self.assertEqual(10, df3.count())
+
+ def test_infer_schema_to_local(self):
+ input = [{"a": 1}, {"b": "coffee"}]
+ rdd = self.sc.parallelize(input)
+ df = self.spark.createDataFrame(input)
+ df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
+ self.assertEqual(df.schema, df2.schema)
+
+ rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+ df3 = self.spark.createDataFrame(rdd, df.schema)
+ self.assertEqual(10, df3.count())
+
+ def test_apply_schema_to_dict_and_rows(self):
+ schema = StructType().add("b", StringType()).add("a", IntegerType())
+ input = [{"a": 1}, {"b": "coffee"}]
+ rdd = self.sc.parallelize(input)
+ for verify in [False, True]:
+ df = self.spark.createDataFrame(input, schema, verifySchema=verify)
+ df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+ self.assertEqual(df.schema, df2.schema)
+
+ rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
+ df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
+ self.assertEqual(10, df3.count())
+ input = [Row(a=x, b=str(x)) for x in range(10)]
+ df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
+ self.assertEqual(10, df4.count())
+
+ def test_create_dataframe_schema_mismatch(self):
+ input = [Row(a=1)]
+ rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
+ schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
+ df = self.spark.createDataFrame(rdd, schema)
+ self.assertRaises(Exception, lambda: df.show())
+
+ def test_infer_schema(self):
+ d = [Row(l=[], d={}, s=None),
+ Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
+ rdd = self.sc.parallelize(d)
+ df = self.spark.createDataFrame(rdd)
+ self.assertEqual([], df.rdd.map(lambda r: r.l).first())
+ self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
+
+ with self.tempView("test"):
+ df.createOrReplaceTempView("test")
+ result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'")
+ self.assertEqual(1, result.head()[0])
+
+ df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
+ self.assertEqual(df.schema, df2.schema)
+ self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
+ self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
+
+ with self.tempView("test2"):
+ df2.createOrReplaceTempView("test2")
+ result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+ self.assertEqual(1, result.head()[0])
+
+ def test_infer_schema_specification(self):
+ from decimal import Decimal
+
+ class A(object):
+ def __init__(self):
+ self.a = 1
+
+ data = [
+ True,
+ 1,
+ "a",
+ u"a",
+ datetime.date(1970, 1, 1),
+ datetime.datetime(1970, 1, 1, 0, 0),
+ 1.0,
+ array.array("d", [1]),
+ [1],
+ (1, ),
+ {"a": 1},
+ bytearray(1),
+ Decimal(1),
+ Row(a=1),
+ Row("a")(1),
+ A(),
+ ]
+
+ df = self.spark.createDataFrame([data])
+ actual = list(map(lambda x: x.dataType.simpleString(), df.schema))
+ expected = [
+ 'boolean',
+ 'bigint',
+ 'string',
+ 'string',
+ 'date',
+ 'timestamp',
+ 'double',
+            'array<double>',
+            'array<bigint>',
+            'struct<_1:bigint>',
+            'map<string,bigint>',
+ 'binary',
+ 'decimal(38,18)',
+            'struct<a:bigint>',
+            'struct<a:bigint>',
+            'struct<a:bigint>',
+ ]
+ self.assertEqual(actual, expected)
+
+ actual = list(df.first())
+ expected = [
+ True,
+ 1,
+ 'a',
+ u"a",
+ datetime.date(1970, 1, 1),
+ datetime.datetime(1970, 1, 1, 0, 0),
+ 1.0,
+ [1.0],
+ [1],
+ Row(_1=1),
+ {"a": 1},
+ bytearray(b'\x00'),
+ Decimal('1.000000000000000000'),
+ Row(a=1),
+ Row(a=1),
+ Row(a=1),
+ ]
+ self.assertEqual(actual, expected)
+
+ def test_infer_schema_not_enough_names(self):
+ df = self.spark.createDataFrame([["a", "b"]], ["col1"])
+ self.assertEqual(df.columns, ['col1', '_2'])
+
+ def test_infer_schema_fails(self):
+ with self.assertRaisesRegexp(TypeError, 'field a'):
+ self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
+ schema=["a", "b"], samplingRatio=0.99)
+
+ def test_infer_nested_schema(self):
+ NestedRow = Row("f1", "f2")
+ nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
+ NestedRow([2, 3], {"row2": 2.0})])
+ df = self.spark.createDataFrame(nestedRdd1)
+ self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
+
+ nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
+ NestedRow([[2, 3], [3, 4]], [2, 3])])
+ df = self.spark.createDataFrame(nestedRdd2)
+ self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
+
+ from collections import namedtuple
+ CustomRow = namedtuple('CustomRow', 'field1 field2')
+ rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
+ CustomRow(field1=2, field2="row2"),
+ CustomRow(field1=3, field2="row3")])
+ df = self.spark.createDataFrame(rdd)
+ self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
+
+ def test_create_dataframe_from_dict_respects_schema(self):
+ df = self.spark.createDataFrame([{'a': 1}], ["b"])
+ self.assertEqual(df.columns, ['b'])
+
+ def test_create_dataframe_from_objects(self):
+ data = [MyObject(1, "1"), MyObject(2, "2")]
+ df = self.spark.createDataFrame(data)
+ self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
+ self.assertEqual(df.first(), Row(key=1, value="1"))
+
+ def test_apply_schema(self):
+ from datetime import date, datetime
+ rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
+ date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
+ {"a": 1}, (2,), [1, 2, 3], None)])
+ schema = StructType([
+ StructField("byte1", ByteType(), False),
+ StructField("byte2", ByteType(), False),
+ StructField("short1", ShortType(), False),
+ StructField("short2", ShortType(), False),
+ StructField("int1", IntegerType(), False),
+ StructField("float1", FloatType(), False),
+ StructField("date1", DateType(), False),
+ StructField("time1", TimestampType(), False),
+ StructField("map1", MapType(StringType(), IntegerType(), False), False),
+ StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
+ StructField("list1", ArrayType(ByteType(), False), False),
+ StructField("null1", DoubleType(), True)])
+ df = self.spark.createDataFrame(rdd, schema)
+ results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
+ x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
+ r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
+ datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+ self.assertEqual(r, results.first())
+
+ with self.tempView("table2"):
+ df.createOrReplaceTempView("table2")
+ r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+ "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
+ "float1 + 1.5 as float1 FROM table2").first()
+
+ self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
+
+ def test_convert_row_to_dict(self):
+ row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
+ self.assertEqual(1, row.asDict()['l'][0].a)
+ df = self.sc.parallelize([row]).toDF()
+
+ with self.tempView("test"):
+ df.createOrReplaceTempView("test")
+ row = self.spark.sql("select l, d from test").head()
+ self.assertEqual(1, row.asDict()["l"][0].a)
+ self.assertEqual(1.0, row.asDict()['d']['key'].c)
+
+ def test_udt(self):
+ from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
+
+ def check_datatype(datatype):
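+            # The type must survive Python pickling and a JSON round trip through the JVM parser.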
+ pickled = pickle.loads(pickle.dumps(datatype))
+ assert datatype == pickled
+ scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
+ python_datatype = _parse_datatype_json_string(scala_datatype.json())
+ assert datatype == python_datatype
+
+ check_datatype(ExamplePointUDT())
+ structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ check_datatype(structtype_with_udt)
+ p = ExamplePoint(1.0, 2.0)
+ self.assertEqual(_infer_type(p), ExamplePointUDT())
+ _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
+ self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))
+
+ check_datatype(PythonOnlyUDT())
+ structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ StructField("point", PythonOnlyUDT(), False)])
+ check_datatype(structtype_with_udt)
+ p = PythonOnlyPoint(1.0, 2.0)
+ self.assertEqual(_infer_type(p), PythonOnlyUDT())
+ _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
+ self.assertRaises(
+ ValueError,
+ lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
+
+ def test_simple_udt_in_df(self):
+ schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
+ df = self.spark.createDataFrame(
+ [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
+ schema=schema)
+ df.collect()
+
+ def test_nested_udt_in_df(self):
+ schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
+ df = self.spark.createDataFrame(
+ [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
+ schema=schema)
+ df.collect()
+
+ schema = StructType().add("key", LongType()).add("val",
+ MapType(LongType(), PythonOnlyUDT()))
+ df = self.spark.createDataFrame(
+ [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
+ schema=schema)
+ df.collect()
+
+ def test_complex_nested_udt_in_df(self):
+ from pyspark.sql.functions import udf
+
+ schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
+ df = self.spark.createDataFrame(
+ [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
+ schema=schema)
+ df.collect()
+
+ gd = df.groupby("key").agg({"val": "collect_list"})
+ gd.collect()
+ udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
+ gd.select(udf(*gd)).collect()
+
+ def test_udt_with_none(self):
+ df = self.spark.range(0, 10, 1, 1)
+
+ def myudf(x):
+ if x > 0:
+ return PythonOnlyPoint(float(x), float(x))
+
+ self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
+ rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
+ self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
+
+ def test_infer_schema_with_udt(self):
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ df = self.spark.createDataFrame([row])
+ schema = df.schema
+ field = [f for f in schema.fields if f.name == "point"][0]
+ self.assertEqual(type(field.dataType), ExamplePointUDT)
+
+ with self.tempView("labeled_point"):
+ df.createOrReplaceTempView("labeled_point")
+ point = self.spark.sql("SELECT point FROM labeled_point").head().point
+ self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+ row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+ df = self.spark.createDataFrame([row])
+ schema = df.schema
+ field = [f for f in schema.fields if f.name == "point"][0]
+ self.assertEqual(type(field.dataType), PythonOnlyUDT)
+
+ with self.tempView("labeled_point"):
+ df.createOrReplaceTempView("labeled_point")
+ point = self.spark.sql("SELECT point FROM labeled_point").head().point
+ self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+ def test_apply_schema_with_udt(self):
+ row = (1.0, ExamplePoint(1.0, 2.0))
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ df = self.spark.createDataFrame([row], schema)
+ point = df.head().point
+ self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+ row = (1.0, PythonOnlyPoint(1.0, 2.0))
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", PythonOnlyUDT(), False)])
+ df = self.spark.createDataFrame([row], schema)
+ point = df.head().point
+ self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+ def test_udf_with_udt(self):
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ df = self.spark.createDataFrame([row])
+ self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+ udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+ self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+ udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+ self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
+ row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+ df = self.spark.createDataFrame([row])
+ self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+ udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+ self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+ udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
+ self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
+ def test_parquet_with_udt(self):
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ df0 = self.spark.createDataFrame([row])
+ output_dir = os.path.join(self.tempdir.name, "labeled_point")
+ df0.write.parquet(output_dir)
+ df1 = self.spark.read.parquet(output_dir)
+ point = df1.head().point
+ self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+ row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+ df0 = self.spark.createDataFrame([row])
+ df0.write.parquet(output_dir, mode='overwrite')
+ df1 = self.spark.read.parquet(output_dir)
+ point = df1.head().point
+ self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+ def test_union_with_udt(self):
+ row1 = (1.0, ExamplePoint(1.0, 2.0))
+ row2 = (2.0, ExamplePoint(3.0, 4.0))
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ df1 = self.spark.createDataFrame([row1], schema)
+ df2 = self.spark.createDataFrame([row2], schema)
+
+ result = df1.union(df2).orderBy("label").collect()
+ self.assertEqual(
+ result,
+ [
+ Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
+ Row(label=2.0, point=ExamplePoint(3.0, 4.0))
+ ]
+ )
+
+ def test_cast_to_string_with_udt(self):
+ from pyspark.sql.functions import col
+ row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
+ schema = StructType([StructField("point", ExamplePointUDT(), False),
+ StructField("pypoint", PythonOnlyUDT(), False)])
+ df = self.spark.createDataFrame([row], schema)
+
+ result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
+ self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))
+
+ def test_struct_type(self):
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ struct2 = StructType([StructField("f1", StringType(), True),
+ StructField("f2", StringType(), True, None)])
+ self.assertEqual(struct1.fieldNames(), struct2.names)
+ self.assertEqual(struct1, struct2)
+
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ struct2 = StructType([StructField("f1", StringType(), True)])
+ self.assertNotEqual(struct1.fieldNames(), struct2.names)
+ self.assertNotEqual(struct1, struct2)
+
+ struct1 = (StructType().add(StructField("f1", StringType(), True))
+ .add(StructField("f2", StringType(), True, None)))
+ struct2 = StructType([StructField("f1", StringType(), True),
+ StructField("f2", StringType(), True, None)])
+ self.assertEqual(struct1.fieldNames(), struct2.names)
+ self.assertEqual(struct1, struct2)
+
+ struct1 = (StructType().add(StructField("f1", StringType(), True))
+ .add(StructField("f2", StringType(), True, None)))
+ struct2 = StructType([StructField("f1", StringType(), True)])
+ self.assertNotEqual(struct1.fieldNames(), struct2.names)
+ self.assertNotEqual(struct1, struct2)
+
+ # Catch exception raised during improper construction
+ self.assertRaises(ValueError, lambda: StructType().add("name"))
+
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ for field in struct1:
+ self.assertIsInstance(field, StructField)
+
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ self.assertEqual(len(struct1), 2)
+
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ self.assertIs(struct1["f1"], struct1.fields[0])
+ self.assertIs(struct1[0], struct1.fields[0])
+ self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
+ self.assertRaises(KeyError, lambda: struct1["f9"])
+ self.assertRaises(IndexError, lambda: struct1[9])
+ self.assertRaises(TypeError, lambda: struct1[9.9])
+
+ def test_parse_datatype_string(self):
+ from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
+ for k, t in _all_atomic_types.items():
+ if t != NullType:
+ self.assertEqual(t(), _parse_datatype_string(k))
+ self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+ self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
+ self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
+ self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
+ self.assertEqual(
+ ArrayType(IntegerType()),
+ _parse_datatype_string("array"))
+ self.assertEqual(
+ MapType(IntegerType(), DoubleType()),
+ _parse_datatype_string("map< int, double >"))
+ self.assertEqual(
+ StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+ _parse_datatype_string("struct"))
+ self.assertEqual(
+ StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+ _parse_datatype_string("a:int, c:double"))
+ self.assertEqual(
+ StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+ _parse_datatype_string("a INT, c DOUBLE"))
+
+ def test_metadata_null(self):
+ schema = StructType([StructField("f1", StringType(), True, None),
+ StructField("f2", StringType(), True, {'a': None})])
+ rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
+ self.spark.createDataFrame(rdd, schema)
+
+ def test_access_nested_types(self):
+ df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+ self.assertEqual(1, df.select(df.l[0]).first()[0])
+ self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
+ self.assertEqual(1, df.select(df.r.a).first()[0])
+ self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
+ self.assertEqual("v", df.select(df.d["k"]).first()[0])
+ self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
+
+ def test_infer_long_type(self):
+ longrow = [Row(f1='a', f2=100000000000000)]
+ df = self.sc.parallelize(longrow).toDF()
+ self.assertEqual(df.schema.fields[1].dataType, LongType())
+
+        # Saving this as Parquet also used to cause issues.
+ output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+ df.write.parquet(output_dir)
+ df1 = self.spark.read.parquet(output_dir)
+ self.assertEqual('a', df1.first().f1)
+ self.assertEqual(100000000000000, df1.first().f2)
+
+ self.assertEqual(_infer_type(1), LongType())
+ self.assertEqual(_infer_type(2**10), LongType())
+ self.assertEqual(_infer_type(2**20), LongType())
+ self.assertEqual(_infer_type(2**31 - 1), LongType())
+ self.assertEqual(_infer_type(2**31), LongType())
+ self.assertEqual(_infer_type(2**61), LongType())
+ self.assertEqual(_infer_type(2**71), LongType())
+
+ def test_merge_type(self):
+ self.assertEqual(_merge_type(LongType(), NullType()), LongType())
+ self.assertEqual(_merge_type(NullType(), LongType()), LongType())
+
+ self.assertEqual(_merge_type(LongType(), LongType()), LongType())
+
+ self.assertEqual(_merge_type(
+ ArrayType(LongType()),
+ ArrayType(LongType())
+ ), ArrayType(LongType()))
+ with self.assertRaisesRegexp(TypeError, 'element in array'):
+ _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
+
+ self.assertEqual(_merge_type(
+ MapType(StringType(), LongType()),
+ MapType(StringType(), LongType())
+ ), MapType(StringType(), LongType()))
+ with self.assertRaisesRegexp(TypeError, 'key of map'):
+ _merge_type(
+ MapType(StringType(), LongType()),
+ MapType(DoubleType(), LongType()))
+ with self.assertRaisesRegexp(TypeError, 'value of map'):
+ _merge_type(
+ MapType(StringType(), LongType()),
+ MapType(StringType(), DoubleType()))
+
+ self.assertEqual(_merge_type(
+ StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
+ StructType([StructField("f1", LongType()), StructField("f2", StringType())])
+ ), StructType([StructField("f1", LongType()), StructField("f2", StringType())]))
+ with self.assertRaisesRegexp(TypeError, 'field f1'):
+ _merge_type(
+ StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
+ StructType([StructField("f1", DoubleType()), StructField("f2", StringType())]))
+
+ self.assertEqual(_merge_type(
+ StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
+ StructType([StructField("f1", StructType([StructField("f2", LongType())]))])
+ ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))]))
+ with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
+ _merge_type(
+ StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
+ StructType([StructField("f1", StructType([StructField("f2", StringType())]))]))
+
+ self.assertEqual(_merge_type(
+ StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]),
+ StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])
+ ), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]))
+ with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
+ _merge_type(
+ StructType([
+ StructField("f1", ArrayType(LongType())),
+ StructField("f2", StringType())]),
+ StructType([
+ StructField("f1", ArrayType(DoubleType())),
+ StructField("f2", StringType())]))
+
+ self.assertEqual(_merge_type(
+ StructType([
+ StructField("f1", MapType(StringType(), LongType())),
+ StructField("f2", StringType())]),
+ StructType([
+ StructField("f1", MapType(StringType(), LongType())),
+ StructField("f2", StringType())])
+ ), StructType([
+ StructField("f1", MapType(StringType(), LongType())),
+ StructField("f2", StringType())]))
+ with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
+ _merge_type(
+ StructType([
+ StructField("f1", MapType(StringType(), LongType())),
+ StructField("f2", StringType())]),
+ StructType([
+ StructField("f1", MapType(StringType(), DoubleType())),
+ StructField("f2", StringType())]))
+
+ self.assertEqual(_merge_type(
+ StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
+ StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])
+ ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]))
+ with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
+ _merge_type(
+ StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
+ StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))])
+ )
+
+ # test for SPARK-16542
+ def test_array_types(self):
+        # This test needs to make sure that the Scala type selected is at least
+        # as large as Python's types. This is necessary because Python's array
+        # types depend on the C implementation on the machine, so there is no
+        # machine-independent correspondence between Python's array types and
+        # Scala types.
+        # See: https://docs.python.org/2/library/array.html
+
+ def assertCollectSuccess(typecode, value):
+ row = Row(myarray=array.array(typecode, [value]))
+ df = self.spark.createDataFrame([row])
+ self.assertEqual(df.first()["myarray"][0], value)
+
+ # supported string types
+ #
+ # String types in python's array are "u" for Py_UNICODE and "c" for char.
+ # "u" will be removed in python 4, and "c" is not supported in python 3.
+ supported_string_types = []
+ if sys.version_info[0] < 4:
+ supported_string_types += ['u']
+ # test unicode
+ assertCollectSuccess('u', u'a')
+ if sys.version_info[0] < 3:
+ supported_string_types += ['c']
+ # test string
+ assertCollectSuccess('c', 'a')
+
+ # supported float and double
+ #
+ # Test max, min, and precision for float and double, assuming IEEE 754
+ # floating-point format.
+ supported_fractional_types = ['f', 'd']
+ assertCollectSuccess('f', ctypes.c_float(1e+38).value)
+ assertCollectSuccess('f', ctypes.c_float(1e-38).value)
+ assertCollectSuccess('f', ctypes.c_float(1.123456).value)
+ assertCollectSuccess('d', sys.float_info.max)
+ assertCollectSuccess('d', sys.float_info.min)
+ assertCollectSuccess('d', sys.float_info.epsilon)
+
+ # supported signed int types
+ #
+ # The size of C types changes with implementation, we need to make sure
+ # that there is no overflow error on the platform running this test.
+ supported_signed_int_types = list(
+ set(_array_signed_int_typecode_ctype_mappings.keys())
+ .intersection(set(_array_type_mappings.keys())))
+ for t in supported_signed_int_types:
+ ctype = _array_signed_int_typecode_ctype_mappings[t]
+ max_val = 2 ** (ctypes.sizeof(ctype) * 8 - 1)
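+            # The valid signed range for this C type is [-max_val, max_val - 1].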
+ assertCollectSuccess(t, max_val - 1)
+ assertCollectSuccess(t, -max_val)
+
+ # supported unsigned int types
+ #
+ # JVM does not have unsigned types. We need to be very careful to make
+ # sure that there is no overflow error.
+ supported_unsigned_int_types = list(
+ set(_array_unsigned_int_typecode_ctype_mappings.keys())
+ .intersection(set(_array_type_mappings.keys())))
+ for t in supported_unsigned_int_types:
+ ctype = _array_unsigned_int_typecode_ctype_mappings[t]
+ assertCollectSuccess(t, 2 ** (ctypes.sizeof(ctype) * 8) - 1)
+
+ # all supported types
+ #
+ # Make sure the types tested above:
+ # 1. are all supported types
+ # 2. cover all supported types
+ supported_types = (supported_string_types +
+ supported_fractional_types +
+ supported_signed_int_types +
+ supported_unsigned_int_types)
+ self.assertEqual(set(supported_types), set(_array_type_mappings.keys()))
+
+ # all unsupported types
+ #
+        # The keys of _array_type_mappings form the complete list of supported types;
+        # typecodes not in _array_type_mappings are considered unsupported.
+        # `array.typecodes` is not available in python 2.
+ if sys.version_info[0] < 3:
+ all_types = set(['c', 'b', 'B', 'u', 'h', 'H', 'i', 'I', 'l', 'L', 'f', 'd'])
+ else:
+ all_types = set(array.typecodes)
+ unsupported_types = all_types - set(supported_types)
+ # test unsupported types
+ for t in unsupported_types:
+ with self.assertRaises(TypeError):
+ a = array.array(t)
+ self.spark.createDataFrame([Row(myarray=a)]).collect()
+
+
+class DataTypeTests(unittest.TestCase):
+ # regression test for SPARK-6055
+ def test_data_type_eq(self):
+ lt = LongType()
+ lt2 = pickle.loads(pickle.dumps(LongType()))
+ self.assertEqual(lt, lt2)
+
+ # regression test for SPARK-7978
+ def test_decimal_type(self):
+ t1 = DecimalType()
+ t2 = DecimalType(10, 2)
+ self.assertTrue(t2 is not t1)
+ self.assertNotEqual(t1, t2)
+ t3 = DecimalType(8)
+ self.assertNotEqual(t2, t3)
+
+ # regression test for SPARK-10392
+ def test_datetype_equal_zero(self):
+ dt = DateType()
+ self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))
+
+ # regression test for SPARK-17035
+ def test_timestamp_microsecond(self):
+ tst = TimestampType()
+ self.assertEqual(tst.toInternal(datetime.datetime.max) % 1000000, 999999)
+
+ def test_empty_row(self):
+ row = Row()
+ self.assertEqual(len(row), 0)
+
+ def test_struct_field_type_name(self):
+ struct_field = StructField("a", IntegerType())
+ self.assertRaises(TypeError, struct_field.typeName)
+
+ def test_invalid_create_row(self):
+ row_class = Row("c1", "c2")
+ self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
+
+
+class DataTypeVerificationTests(unittest.TestCase):
+
+ def test_verify_type_exception_msg(self):
+ self.assertRaisesRegexp(
+ ValueError,
+ "test_name",
+ lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))
+
+ schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
+ self.assertRaisesRegexp(
+ TypeError,
+ "field b in field a",
+ lambda: _make_type_verifier(schema)([["data"]]))
+
+ def test_verify_type_ok_nullable(self):
+ obj = None
+ types = [IntegerType(), FloatType(), StringType(), StructType([])]
+ for data_type in types:
+ try:
+ _make_type_verifier(data_type, nullable=True)(obj)
+ except Exception:
+ self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type))
+
+ def test_verify_type_not_nullable(self):
+ import array
+ import datetime
+ import decimal
+
+ schema = StructType([
+ StructField('s', StringType(), nullable=False),
+ StructField('i', IntegerType(), nullable=True)])
+
+ class MyObj:
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ # obj, data_type
+ success_spec = [
+ # String
+ ("", StringType()),
+ (u"", StringType()),
+ (1, StringType()),
+ (1.0, StringType()),
+ ([], StringType()),
+ ({}, StringType()),
+
+ # UDT
+ (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
+
+ # Boolean
+ (True, BooleanType()),
+
+ # Byte
+ (-(2**7), ByteType()),
+ (2**7 - 1, ByteType()),
+
+ # Short
+ (-(2**15), ShortType()),
+ (2**15 - 1, ShortType()),
+
+ # Integer
+ (-(2**31), IntegerType()),
+ (2**31 - 1, IntegerType()),
+
+ # Long
+ (2**64, LongType()),
+
+ # Float & Double
+ (1.0, FloatType()),
+ (1.0, DoubleType()),
+
+ # Decimal
+ (decimal.Decimal("1.0"), DecimalType()),
+
+ # Binary
+ (bytearray([1, 2]), BinaryType()),
+
+ # Date/Timestamp
+ (datetime.date(2000, 1, 2), DateType()),
+ (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
+ (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
+
+ # Array
+ ([], ArrayType(IntegerType())),
+ (["1", None], ArrayType(StringType(), containsNull=True)),
+ ([1, 2], ArrayType(IntegerType())),
+ ((1, 2), ArrayType(IntegerType())),
+ (array.array('h', [1, 2]), ArrayType(IntegerType())),
+
+ # Map
+ ({}, MapType(StringType(), IntegerType())),
+ ({"a": 1}, MapType(StringType(), IntegerType())),
+ ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),
+
+ # Struct
+ ({"s": "a", "i": 1}, schema),
+ ({"s": "a", "i": None}, schema),
+ ({"s": "a"}, schema),
+ ({"s": "a", "f": 1.0}, schema),
+ (Row(s="a", i=1), schema),
+ (Row(s="a", i=None), schema),
+ (Row(s="a", i=1, f=1.0), schema),
+ (["a", 1], schema),
+ (["a", None], schema),
+ (("a", 1), schema),
+ (MyObj(s="a", i=1), schema),
+ (MyObj(s="a", i=None), schema),
+ (MyObj(s="a"), schema),
+ ]
+
+ # obj, data_type, exception class
+ failure_spec = [
+ # String (match anything but None)
+ (None, StringType(), ValueError),
+
+ # UDT
+ (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
+
+ # Boolean
+ (1, BooleanType(), TypeError),
+ ("True", BooleanType(), TypeError),
+ ([1], BooleanType(), TypeError),
+
+ # Byte
+ (-(2**7) - 1, ByteType(), ValueError),
+ (2**7, ByteType(), ValueError),
+ ("1", ByteType(), TypeError),
+ (1.0, ByteType(), TypeError),
+
+ # Short
+ (-(2**15) - 1, ShortType(), ValueError),
+ (2**15, ShortType(), ValueError),
+
+ # Integer
+ (-(2**31) - 1, IntegerType(), ValueError),
+ (2**31, IntegerType(), ValueError),
+
+ # Float & Double
+ (1, FloatType(), TypeError),
+ (1, DoubleType(), TypeError),
+
+ # Decimal
+ (1.0, DecimalType(), TypeError),
+ (1, DecimalType(), TypeError),
+ ("1.0", DecimalType(), TypeError),
+
+ # Binary
+ (1, BinaryType(), TypeError),
+
+ # Date/Timestamp
+ ("2000-01-02", DateType(), TypeError),
+ (946811040, TimestampType(), TypeError),
+
+ # Array
+ (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
+ ([1, "2"], ArrayType(IntegerType()), TypeError),
+
+ # Map
+ ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
+ ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
+ ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
+ ValueError),
+
+ # Struct
+ ({"s": "a", "i": "1"}, schema, TypeError),
+ (Row(s="a"), schema, ValueError), # Row can't have missing field
+ (Row(s="a", i="1"), schema, TypeError),
+ (["a"], schema, ValueError),
+ (["a", "1"], schema, TypeError),
+ (MyObj(s="a", i="1"), schema, TypeError),
+ (MyObj(s=None, i="1"), schema, ValueError),
+ ]
+
+ # Check success cases
+ for obj, data_type in success_spec:
+ try:
+ _make_type_verifier(data_type, nullable=False)(obj)
+ except Exception:
+ self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type))
+
+ # Check failure cases
+ for obj, data_type, exp in failure_spec:
+ msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
+ with self.assertRaises(exp, msg=msg):
+ _make_type_verifier(data_type, nullable=False)(obj)
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_types import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
new file mode 100644
index 0000000000000..ed298f724d551
--- /dev/null
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -0,0 +1,667 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+import pydoc
+import shutil
+import tempfile
+import unittest
+
+from pyspark import SparkContext
+from pyspark.sql import SparkSession, Column, Row
+from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.types import *
+from pyspark.sql.utils import AnalysisException
+from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
+from pyspark.testing.utils import QuietTest
+
+
+class UDFTests(ReusedSQLTestCase):
+
+ def test_udf_with_callable(self):
+ d = [Row(number=i, squared=i**2) for i in range(10)]
+ rdd = self.sc.parallelize(d)
+ data = self.spark.createDataFrame(rdd)
+
+ class PlusFour:
+ def __call__(self, col):
+ if col is not None:
+ return col + 4
+
+ call = PlusFour()
+ pudf = UserDefinedFunction(call, LongType())
+ res = data.select(pudf(data['number']).alias('plus_four'))
+ self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
+ def test_udf_with_partial_function(self):
+ d = [Row(number=i, squared=i**2) for i in range(10)]
+ rdd = self.sc.parallelize(d)
+ data = self.spark.createDataFrame(rdd)
+
+ def some_func(col, param):
+ if col is not None:
+ return col + param
+
+ pfunc = functools.partial(some_func, param=4)
+ pudf = UserDefinedFunction(pfunc, LongType())
+ res = data.select(pudf(data['number']).alias('plus_four'))
+ self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
+ def test_udf(self):
+ self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+ [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+ self.assertEqual(row[0], 5)
+
+        # This checks that the deprecated 'SQLContext.registerFunction' still works via its alias.
+ sqlContext = self.spark._wrapped
+ sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
+ [row] = sqlContext.sql("SELECT oneArg('test')").collect()
+ self.assertEqual(row[0], 4)
+
+ def test_udf2(self):
+ with self.tempView("test"):
+ self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
+ self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
+ .createOrReplaceTempView("test")
+ [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
+ self.assertEqual(4, res[0])
+
+ def test_udf3(self):
+ two_args = self.spark.catalog.registerFunction(
+ "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y))
+ self.assertEqual(two_args.deterministic, True)
+ [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+ self.assertEqual(row[0], u'5')
+
+ def test_udf_registration_return_type_none(self):
+ two_args = self.spark.catalog.registerFunction(
+ "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None)
+ self.assertEqual(two_args.deterministic, True)
+ [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+ self.assertEqual(row[0], 5)
+
+ def test_udf_registration_return_type_not_none(self):
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(TypeError, "Invalid returnType"):
+ self.spark.catalog.registerFunction(
+ "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType())
+
+ def test_nondeterministic_udf(self):
+ # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
+ from pyspark.sql.functions import udf
+ import random
+ udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
+ self.assertEqual(udf_random_col.deterministic, False)
+ df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
+ udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
+ [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
+ self.assertEqual(row[0] + 10, row[1])
+
+ def test_nondeterministic_udf2(self):
+ import random
+ from pyspark.sql.functions import udf
+ random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
+ self.assertEqual(random_udf.deterministic, False)
+ random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
+ self.assertEqual(random_udf1.deterministic, False)
+ [row] = self.spark.sql("SELECT randInt()").collect()
+ self.assertEqual(row[0], 6)
+ [row] = self.spark.range(1).select(random_udf1()).collect()
+ self.assertEqual(row[0], 6)
+ [row] = self.spark.range(1).select(random_udf()).collect()
+ self.assertEqual(row[0], 6)
+ # render_doc() reproduces the help() exception without printing output
+ pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
+ pydoc.render_doc(random_udf)
+ pydoc.render_doc(random_udf1)
+ pydoc.render_doc(udf(lambda x: x).asNondeterministic)
+
+ def test_nondeterministic_udf3(self):
+ # regression test for SPARK-23233
+ from pyspark.sql.functions import udf
+ f = udf(lambda x: x)
+ # Here we cache the JVM UDF instance.
+ self.spark.range(1).select(f("id"))
+ # This should reset the cache to set the deterministic status correctly.
+ f = f.asNondeterministic()
+ # Check the deterministic status of udf.
+ df = self.spark.range(1).select(f("id"))
+ deterministic = df._jdf.logicalPlan().projectList().head().deterministic()
+ self.assertFalse(deterministic)
+
+ def test_nondeterministic_udf_in_aggregate(self):
+ from pyspark.sql.functions import udf, sum
+ import random
+ udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
+ df = self.spark.range(10)
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
+ df.groupby('id').agg(sum(udf_random_col())).collect()
+ with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
+ df.agg(sum(udf_random_col())).collect()
+
+ def test_chained_udf(self):
+ self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
+ [row] = self.spark.sql("SELECT double(1)").collect()
+ self.assertEqual(row[0], 2)
+ [row] = self.spark.sql("SELECT double(double(1))").collect()
+ self.assertEqual(row[0], 4)
+ [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
+ self.assertEqual(row[0], 6)
+
+ def test_single_udf_with_repeated_argument(self):
+ # regression test for SPARK-20685
+ self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+ row = self.spark.sql("SELECT add(1, 1)").first()
+ self.assertEqual(tuple(row), (2, ))
+
+ def test_multiple_udfs(self):
+ self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
+ [row] = self.spark.sql("SELECT double(1), double(2)").collect()
+ self.assertEqual(tuple(row), (2, 4))
+ [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+ self.assertEqual(tuple(row), (4, 12))
+ self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+ [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+ self.assertEqual(tuple(row), (6, 5))
+
+ def test_udf_in_filter_on_top_of_outer_join(self):
+ from pyspark.sql.functions import udf
+ left = self.spark.createDataFrame([Row(a=1)])
+ right = self.spark.createDataFrame([Row(a=1)])
+ df = left.join(right, on='a', how='left_outer')
+ df = df.withColumn('b', udf(lambda x: 'x')(df.a))
+ self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
+
+ def test_udf_in_filter_on_top_of_join(self):
+ # regression test for SPARK-18589
+ from pyspark.sql.functions import udf
+ left = self.spark.createDataFrame([Row(a=1)])
+ right = self.spark.createDataFrame([Row(b=1)])
+ f = udf(lambda a, b: a == b, BooleanType())
+ df = left.crossJoin(right).filter(f("a", "b"))
+ self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
+ def test_udf_in_join_condition(self):
+ # regression test for SPARK-25314
+ from pyspark.sql.functions import udf
+ left = self.spark.createDataFrame([Row(a=1)])
+ right = self.spark.createDataFrame([Row(b=1)])
+ f = udf(lambda a, b: a == b, BooleanType())
+ df = left.join(right, f("a", "b"))
+ with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+ df.collect()
+ with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+ self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
+ def test_udf_in_left_outer_join_condition(self):
+ # regression test for SPARK-26147
+ from pyspark.sql.functions import udf, col
+ left = self.spark.createDataFrame([Row(a=1)])
+ right = self.spark.createDataFrame([Row(b=1)])
+ f = udf(lambda a: str(a), StringType())
+        # The join condition can't be pushed down, as it refers to attributes from both sides.
+        # The Python UDF only refers to attributes from one side, so it can be evaluated on its own.
+ df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
+ with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+ self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
+ def test_udf_in_left_semi_join_condition(self):
+ # regression test for SPARK-25314
+ from pyspark.sql.functions import udf
+ left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+ right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
+ f = udf(lambda a, b: a == b, BooleanType())
+ df = left.join(right, f("a", "b"), "leftsemi")
+ with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+ df.collect()
+ with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+ self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+ def test_udf_and_common_filter_in_join_condition(self):
+ # regression test for SPARK-25314
+        # test the scenario where the join condition has both a udf and an ordinary filter
+ from pyspark.sql.functions import udf
+ left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+ right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+ f = udf(lambda a, b: a == b, BooleanType())
+ df = left.join(right, [f("a", "b"), left.a1 == right.b1])
+        # spark.sql.crossJoin.enabled=true is not needed because the udf is not the only join condition.
+ self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+
+ def test_udf_and_common_filter_in_left_semi_join_condition(self):
+ # regression test for SPARK-25314
+        # test the scenario where the join condition has both a udf and an ordinary filter
+ from pyspark.sql.functions import udf
+ left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+ right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+ f = udf(lambda a, b: a == b, BooleanType())
+ df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi")
+        # spark.sql.crossJoin.enabled=true is not needed because the udf is not the only join condition.
+ self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+ def test_udf_not_supported_in_join_condition(self):
+ # regression test for SPARK-25314
+        # test that a python udf in the join condition is not supported for join types
+        # other than inner and left semi.
+ from pyspark.sql.functions import udf
+ left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+ right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+ f = udf(lambda a, b: a == b, BooleanType())
+
+ def runWithJoinType(join_type, type_string):
+ with self.assertRaisesRegexp(
+ AnalysisException,
+ 'Using PythonUDF.*%s is not supported.' % type_string):
+ left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
+ runWithJoinType("full", "FullOuter")
+ runWithJoinType("left", "LeftOuter")
+ runWithJoinType("right", "RightOuter")
+ runWithJoinType("leftanti", "LeftAnti")
+
+ def test_udf_without_arguments(self):
+ self.spark.catalog.registerFunction("foo", lambda: "bar")
+ [row] = self.spark.sql("SELECT foo()").collect()
+ self.assertEqual(row[0], "bar")
+
+ def test_udf_with_array_type(self):
+ with self.tempView("test"):
+ d = [Row(l=list(range(3)), d={"key": list(range(5))})]
+ rdd = self.sc.parallelize(d)
+ self.spark.createDataFrame(rdd).createOrReplaceTempView("test")
+ self.spark.catalog.registerFunction(
+ "copylist", lambda l: list(l), ArrayType(IntegerType()))
+ self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType())
+ [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect()
+ self.assertEqual(list(range(3)), l1)
+ self.assertEqual(1, l2)
+
+ def test_broadcast_in_udf(self):
+ bar = {"a": "aa", "b": "bb", "c": "abc"}
+ foo = self.sc.broadcast(bar)
+ self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+ [res] = self.spark.sql("SELECT MYUDF('c')").collect()
+ self.assertEqual("abc", res[0])
+ [res] = self.spark.sql("SELECT MYUDF('')").collect()
+ self.assertEqual("", res[0])
+
+ def test_udf_with_filter_function(self):
+ df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ from pyspark.sql.functions import udf, col
+ from pyspark.sql.types import BooleanType
+
+ my_filter = udf(lambda a: a < 2, BooleanType())
+ sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
+ self.assertEqual(sel.collect(), [Row(key=1, value='1')])
+
+ def test_udf_with_aggregate_function(self):
+ df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ from pyspark.sql.functions import udf, col, sum
+ from pyspark.sql.types import BooleanType
+
+ my_filter = udf(lambda a: a == 1, BooleanType())
+ sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
+ self.assertEqual(sel.collect(), [Row(key=1)])
+
+ my_copy = udf(lambda x: x, IntegerType())
+ my_add = udf(lambda a, b: int(a + b), IntegerType())
+ my_strlen = udf(lambda x: len(x), IntegerType())
+ sel = df.groupBy(my_copy(col("key")).alias("k"))\
+ .agg(sum(my_strlen(col("value"))).alias("s"))\
+ .select(my_add(col("k"), col("s")).alias("t"))
+ self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
+ def test_udf_in_generate(self):
+ from pyspark.sql.functions import udf, explode
+ df = self.spark.range(5)
+ f = udf(lambda x: list(range(x)), ArrayType(LongType()))
+ row = df.select(explode(f(*df))).groupBy().sum().first()
+ self.assertEqual(row[0], 10)
+
+ df = self.spark.range(3)
+ res = df.select("id", explode(f(df.id))).collect()
+ self.assertEqual(res[0][0], 1)
+ self.assertEqual(res[0][1], 0)
+ self.assertEqual(res[1][0], 2)
+ self.assertEqual(res[1][1], 0)
+ self.assertEqual(res[2][0], 2)
+ self.assertEqual(res[2][1], 1)
+
+ range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
+ res = df.select("id", explode(range_udf(df.id))).collect()
+ self.assertEqual(res[0][0], 0)
+ self.assertEqual(res[0][1], -1)
+ self.assertEqual(res[1][0], 0)
+ self.assertEqual(res[1][1], 0)
+ self.assertEqual(res[2][0], 1)
+ self.assertEqual(res[2][1], 0)
+ self.assertEqual(res[3][0], 1)
+ self.assertEqual(res[3][1], 1)
+
+ def test_udf_with_order_by_and_limit(self):
+ from pyspark.sql.functions import udf
+ my_copy = udf(lambda x: x, IntegerType())
+ df = self.spark.range(10).orderBy("id")
+ res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
+ res.explain(True)
+ self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+
+ def test_udf_registration_returns_udf(self):
+ df = self.spark.range(10)
+ add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
+
+ self.assertListEqual(
+ df.selectExpr("add_three(id) AS plus_three").collect(),
+ df.select(add_three("id").alias("plus_three")).collect()
+ )
+
+        # This checks that UDFs can also be registered through the 'SQLContext.udf' alias.
+ sqlContext = self.spark._wrapped
+ add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
+
+ self.assertListEqual(
+ df.selectExpr("add_four(id) AS plus_four").collect(),
+ df.select(add_four("id").alias("plus_four")).collect()
+ )
+
+ def test_non_existed_udf(self):
+ spark = self.spark
+ self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
+ lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
+
+        # This checks that the deprecated 'SQLContext.registerJavaFunction' alias behaves the same way.
+ sqlContext = spark._wrapped
+ self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
+ lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
+
+ def test_non_existed_udaf(self):
+ spark = self.spark
+ self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
+ lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
+
+ def test_udf_with_input_file_name(self):
+ from pyspark.sql.functions import udf, input_file_name
+ sourceFile = udf(lambda path: path, StringType())
+ filePath = "python/test_support/sql/people1.json"
+ row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
+ self.assertTrue(row[0].find("people1.json") != -1)
+
+ def test_udf_with_input_file_name_for_hadooprdd(self):
+ from pyspark.sql.functions import udf, input_file_name
+
+ def filename(path):
+ return path
+
+ sameText = udf(filename, StringType())
+
+ rdd = self.sc.textFile('python/test_support/sql/people.json')
+ df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
+ row = df.select(sameText(df['file'])).first()
+ self.assertTrue(row[0].find("people.json") != -1)
+
+ rdd2 = self.sc.newAPIHadoopFile(
+ 'python/test_support/sql/people.json',
+ 'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
+ 'org.apache.hadoop.io.LongWritable',
+ 'org.apache.hadoop.io.Text')
+
+ df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
+ row2 = df2.select(sameText(df2['file'])).first()
+ self.assertTrue(row2[0].find("people.json") != -1)
+
+ def test_udf_defers_judf_initialization(self):
+        # This is kept separate from UDFInitializationTests
+        # to avoid initializing a context
+        # when the udf is called
+
+ from pyspark.sql.functions import UserDefinedFunction
+
+ f = UserDefinedFunction(lambda x: x, StringType())
+
+ self.assertIsNone(
+ f._judf_placeholder,
+ "judf should not be initialized before the first call."
+ )
+
+ self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")
+
+ self.assertIsNotNone(
+ f._judf_placeholder,
+ "judf should be initialized after UDF has been called."
+ )
+
+ def test_udf_with_string_return_type(self):
+ from pyspark.sql.functions import UserDefinedFunction
+
+ add_one = UserDefinedFunction(lambda x: x + 1, "integer")
+        make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
+        make_array = UserDefinedFunction(
+            lambda x: [float(x) for x in range(x, x + 3)], "array<double>")
+
+ expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
+ actual = (self.spark.range(1, 2).toDF("x")
+ .select(add_one("x"), make_pair("x"), make_array("x"))
+ .first())
+
+ self.assertTupleEqual(expected, actual)
+
+ def test_udf_shouldnt_accept_noncallable_object(self):
+ from pyspark.sql.functions import UserDefinedFunction
+
+ non_callable = None
+ self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
+
+ def test_udf_with_decorator(self):
+ from pyspark.sql.functions import lit, udf
+ from pyspark.sql.types import IntegerType, DoubleType
+
+ @udf(IntegerType())
+ def add_one(x):
+ if x is not None:
+ return x + 1
+
+ @udf(returnType=DoubleType())
+ def add_two(x):
+ if x is not None:
+ return float(x + 2)
+
+ @udf
+ def to_upper(x):
+ if x is not None:
+ return x.upper()
+
+ @udf()
+ def to_lower(x):
+ if x is not None:
+ return x.lower()
+
+ @udf
+ def substr(x, start, end):
+ if x is not None:
+ return x[start:end]
+
+ @udf("long")
+ def trunc(x):
+ return int(x)
+
+ @udf(returnType="double")
+ def as_double(x):
+ return float(x)
+
+ df = (
+ self.spark
+ .createDataFrame(
+ [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
+ .select(
+ add_one("one"), add_two("one"),
+ to_upper("Foo"), to_lower("Foo"),
+ substr("foobar", lit(0), lit(3)),
+ trunc("float"), as_double("one")))
+
+ self.assertListEqual(
+ [tpe for _, tpe in df.dtypes],
+ ["int", "double", "string", "string", "string", "bigint", "double"]
+ )
+
+ self.assertListEqual(
+ list(df.first()),
+ [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
+ )
+
+ def test_udf_wrapper(self):
+ from pyspark.sql.functions import udf
+ from pyspark.sql.types import IntegerType
+
+ def f(x):
+ """Identity"""
+ return x
+
+ return_type = IntegerType()
+ f_ = udf(f, return_type)
+
+ self.assertTrue(f.__doc__ in f_.__doc__)
+ self.assertEqual(f, f_.func)
+ self.assertEqual(return_type, f_.returnType)
+
+ class F(object):
+ """Identity"""
+ def __call__(self, x):
+ return x
+
+ f = F()
+ return_type = IntegerType()
+ f_ = udf(f, return_type)
+
+ self.assertTrue(f.__doc__ in f_.__doc__)
+ self.assertEqual(f, f_.func)
+ self.assertEqual(return_type, f_.returnType)
+
+ f = functools.partial(f, x=1)
+ return_type = IntegerType()
+ f_ = udf(f, return_type)
+
+ self.assertTrue(f.__doc__ in f_.__doc__)
+ self.assertEqual(f, f_.func)
+ self.assertEqual(return_type, f_.returnType)
+
+ def test_nonparam_udf_with_aggregate(self):
+ import pyspark.sql.functions as f
+
+ df = self.spark.createDataFrame([(1, 2), (1, 2)])
+ f_udf = f.udf(lambda: "const_str")
+ rows = df.distinct().withColumn("a", f_udf()).collect()
+ self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])
+
+ # SPARK-24721
+ @unittest.skipIf(not test_compiled, test_not_compiled_message)
+ def test_datasource_with_udf(self):
+ from pyspark.sql.functions import udf, lit, col
+
+ path = tempfile.mkdtemp()
+ shutil.rmtree(path)
+
+ try:
+ self.spark.range(1).write.mode("overwrite").format('csv').save(path)
+ filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
+ datasource_df = self.spark.read \
+ .format("org.apache.spark.sql.sources.SimpleScanSource") \
+ .option('from', 0).option('to', 1).load().toDF('i')
+ datasource_v2_df = self.spark.read \
+ .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
+ .load().toDF('i', 'j')
+
+ c1 = udf(lambda x: x + 1, 'int')(lit(1))
+ c2 = udf(lambda x: x + 1, 'int')(col('i'))
+
+ f1 = udf(lambda x: False, 'boolean')(lit(1))
+ f2 = udf(lambda x: False, 'boolean')(col('i'))
+
+ for df in [filesource_df, datasource_df, datasource_v2_df]:
+ result = df.withColumn('c', c1)
+ expected = df.withColumn('c', lit(2))
+ self.assertEquals(expected.collect(), result.collect())
+
+ for df in [filesource_df, datasource_df, datasource_v2_df]:
+ result = df.withColumn('c', c2)
+ expected = df.withColumn('c', col('i') + 1)
+ self.assertEquals(expected.collect(), result.collect())
+
+ for df in [filesource_df, datasource_df, datasource_v2_df]:
+ for f in [f1, f2]:
+ result = df.filter(f)
+ self.assertEquals(0, result.count())
+ finally:
+ shutil.rmtree(path)
+
+ # SPARK-25591
+ def test_same_accumulator_in_udfs(self):
+ from pyspark.sql.functions import udf
+
+ data_schema = StructType([StructField("a", IntegerType(), True),
+ StructField("b", IntegerType(), True)])
+ data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
+
+ test_accum = self.sc.accumulator(0)
+
+ def first_udf(x):
+ test_accum.add(1)
+ return x
+
+ def second_udf(x):
+ test_accum.add(100)
+ return x
+
+ func_udf = udf(first_udf, IntegerType())
+ func_udf2 = udf(second_udf, IntegerType())
+ data = data.withColumn("out1", func_udf(data["a"]))
+ data = data.withColumn("out2", func_udf2(data["b"]))
+ data.collect()
+ self.assertEqual(test_accum.value, 101)
+
+
+class UDFInitializationTests(unittest.TestCase):
+ def tearDown(self):
+ if SparkSession._instantiatedSession is not None:
+ SparkSession._instantiatedSession.stop()
+
+ if SparkContext._active_spark_context is not None:
+ SparkContext._active_spark_context.stop()
+
+ def test_udf_init_shouldnt_initialize_context(self):
+ from pyspark.sql.functions import UserDefinedFunction
+
+ UserDefinedFunction(lambda x: x, StringType())
+
+ self.assertIsNone(
+ SparkContext._active_spark_context,
+ "SparkContext shouldn't be initialized when UserDefinedFunction is created."
+ )
+ self.assertIsNone(
+ SparkSession._instantiatedSession,
+ "SparkSession shouldn't be initialized when UserDefinedFunction is created."
+ )
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_udf import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
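A minimal, hypothetical sketch (not part of the patch) of the asNondeterministic() behaviour exercised by the UDF tests above; the session name, registered function name, and column alias are illustrative:

import random

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.appName("udf-sketch").getOrCreate()

# Marking the UDF nondeterministic tells the optimizer not to reorder or duplicate
# its evaluation; such UDFs are rejected inside aggregate expressions.
rand_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
spark.range(3).select(rand_udf().alias("r")).show()

# The registered variant keeps the nondeterministic flag and is callable from SQL.
spark.udf.register("rand_int", rand_udf)
spark.sql("SELECT rand_int()").show()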
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
new file mode 100644
index 0000000000000..5bb921da5c2f3
--- /dev/null
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -0,0 +1,55 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.functions import sha2
+from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class UtilsTests(ReusedSQLTestCase):
+
+ def test_capture_analysis_exception(self):
+ self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
+ self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
+
+ def test_capture_parse_exception(self):
+ self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
+
+ def test_capture_illegalargument_exception(self):
+ self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
+ lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
+ df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
+ self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
+ lambda: df.select(sha2(df.a, 1024)).collect())
+ try:
+ df.select(sha2(df.a, 1024)).collect()
+ except IllegalArgumentException as e:
+ self.assertRegexpMatches(e.desc, "1024 is not in the permitted values")
+ self.assertRegexpMatches(e.stackTrace,
+ "org.apache.spark.sql.functions")
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.test_utils import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
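A minimal sketch of the exception translation that test_utils.py above relies on; the queries are illustrative:

from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException, ParseException

spark = SparkSession.builder.getOrCreate()

try:
    spark.sql("select some_missing_column")  # unresolved attribute -> AnalysisException
except AnalysisException as e:
    print(e.desc)

try:
    spark.sql("this is not SQL")  # parser failure -> ParseException
except ParseException as e:
    print(e.desc)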
diff --git a/python/pyspark/streaming/tests/__init__.py b/python/pyspark/streaming/tests/__init__.py
new file mode 100644
index 0000000000000..cce3acad34a49
--- /dev/null
+++ b/python/pyspark/streaming/tests/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/streaming/tests/test_context.py b/python/pyspark/streaming/tests/test_context.py
new file mode 100644
index 0000000000000..b44121462a920
--- /dev/null
+++ b/python/pyspark/streaming/tests/test_context.py
@@ -0,0 +1,184 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import struct
+import tempfile
+import time
+
+from pyspark.streaming import StreamingContext
+from pyspark.testing.streamingutils import PySparkStreamingTestCase
+
+
+class StreamingContextTests(PySparkStreamingTestCase):
+
+ duration = 0.1
+ setupCalled = False
+
+ def _add_input_stream(self):
+ inputs = [range(1, x) for x in range(101)]
+ stream = self.ssc.queueStream(inputs)
+ self._collect(stream, 1, block=False)
+
+ def test_stop_only_streaming_context(self):
+ self._add_input_stream()
+ self.ssc.start()
+ self.ssc.stop(False)
+ self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)
+
+ def test_stop_multiple_times(self):
+ self._add_input_stream()
+ self.ssc.start()
+ self.ssc.stop(False)
+ self.ssc.stop(False)
+
+ def test_queue_stream(self):
+ input = [list(range(i + 1)) for i in range(3)]
+ dstream = self.ssc.queueStream(input)
+ result = self._collect(dstream, 3)
+ self.assertEqual(input, result)
+
+ def test_text_file_stream(self):
+ d = tempfile.mkdtemp()
+ self.ssc = StreamingContext(self.sc, self.duration)
+ dstream2 = self.ssc.textFileStream(d).map(int)
+ result = self._collect(dstream2, 2, block=False)
+ self.ssc.start()
+ for name in ('a', 'b'):
+ time.sleep(1)
+ with open(os.path.join(d, name), "w") as f:
+ f.writelines(["%d\n" % i for i in range(10)])
+ self.wait_for(result, 2)
+ self.assertEqual([list(range(10)), list(range(10))], result)
+
+ def test_binary_records_stream(self):
+ d = tempfile.mkdtemp()
+ self.ssc = StreamingContext(self.sc, self.duration)
+ dstream = self.ssc.binaryRecordsStream(d, 10).map(
+ lambda v: struct.unpack("10b", bytes(v)))
+ result = self._collect(dstream, 2, block=False)
+ self.ssc.start()
+ for name in ('a', 'b'):
+ time.sleep(1)
+ with open(os.path.join(d, name), "wb") as f:
+ f.write(bytearray(range(10)))
+ self.wait_for(result, 2)
+ self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result])
+
+ def test_union(self):
+ input = [list(range(i + 1)) for i in range(3)]
+ dstream = self.ssc.queueStream(input)
+ dstream2 = self.ssc.queueStream(input)
+ dstream3 = self.ssc.union(dstream, dstream2)
+ result = self._collect(dstream3, 3)
+ expected = [i * 2 for i in input]
+ self.assertEqual(expected, result)
+
+ def test_transform(self):
+ dstream1 = self.ssc.queueStream([[1]])
+ dstream2 = self.ssc.queueStream([[2]])
+ dstream3 = self.ssc.queueStream([[3]])
+
+ def func(rdds):
+ rdd1, rdd2, rdd3 = rdds
+ return rdd2.union(rdd3).union(rdd1)
+
+ dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)
+
+ self.assertEqual([2, 3, 1], self._take(dstream, 3))
+
+ def test_transform_pairrdd(self):
+ # This regression test case is for SPARK-17756.
+ dstream = self.ssc.queueStream(
+ [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd))
+ self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3))
+
+ def test_get_active(self):
+ self.assertEqual(StreamingContext.getActive(), None)
+
+ # Verify that getActive() returns the active context
+ self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
+ self.ssc.start()
+ self.assertEqual(StreamingContext.getActive(), self.ssc)
+
+        # Verify that getActive() returns None once the context is stopped
+ self.ssc.stop(False)
+ self.assertEqual(StreamingContext.getActive(), None)
+
+ # Verify that if the Java context is stopped, then getActive() returns None
+ self.ssc = StreamingContext(self.sc, self.duration)
+ self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
+ self.ssc.start()
+ self.assertEqual(StreamingContext.getActive(), self.ssc)
+ self.ssc._jssc.stop(False)
+ self.assertEqual(StreamingContext.getActive(), None)
+
+ def test_get_active_or_create(self):
+ # Test StreamingContext.getActiveOrCreate() without checkpoint data
+ # See CheckpointTests for tests with checkpoint data
+ self.ssc = None
+ self.assertEqual(StreamingContext.getActive(), None)
+
+ def setupFunc():
+ ssc = StreamingContext(self.sc, self.duration)
+ ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
+ self.setupCalled = True
+ return ssc
+
+ # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active
+ self.setupCalled = False
+ self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
+ self.assertTrue(self.setupCalled)
+
+ # Verify that getActiveOrCreate() returns active context and does not call the setupFunc
+ self.ssc.start()
+ self.setupCalled = False
+ self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc)
+ self.assertFalse(self.setupCalled)
+
+ # Verify that getActiveOrCreate() calls setupFunc after active context is stopped
+ self.ssc.stop(False)
+ self.setupCalled = False
+ self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
+ self.assertTrue(self.setupCalled)
+
+        # Verify that getActiveOrCreate() calls setupFunc if the Java context is stopped
+ self.ssc = StreamingContext(self.sc, self.duration)
+ self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
+ self.ssc.start()
+ self.assertEqual(StreamingContext.getActive(), self.ssc)
+ self.ssc._jssc.stop(False)
+ self.setupCalled = False
+ self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
+ self.assertTrue(self.setupCalled)
+
+ def test_await_termination_or_timeout(self):
+ self._add_input_stream()
+ self.ssc.start()
+ self.assertFalse(self.ssc.awaitTerminationOrTimeout(0.001))
+ self.ssc.stop(False)
+ self.assertTrue(self.ssc.awaitTerminationOrTimeout(0.001))
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.streaming.tests.test_context import *
+
+ try:
+ import xmlrunner
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+ except ImportError:
+ unittest.main(verbosity=2)
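A minimal sketch of the getActiveOrCreate() pattern exercised in test_get_active_or_create above; the app name, one-second batch interval, and queue contents are illustrative:

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext(appName="streaming-sketch")

def setup():
    ssc = StreamingContext(sc, 1)
    ssc.queueStream([[1, 2, 3]]).pprint()
    return ssc

# Returns the currently active context if there is one, otherwise calls setup().
ssc = StreamingContext.getActiveOrCreate(None, setup)
ssc.start()
ssc.stop(stopSparkContext=False)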
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests/test_dstream.py
similarity index 50%
rename from python/pyspark/streaming/tests.py
rename to python/pyspark/streaming/tests/test_dstream.py
index 34e3291651eec..d14e346b7a688 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests/test_dstream.py
@@ -14,151 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
-import glob
-import os
-import sys
-from itertools import chain
-import time
import operator
-import tempfile
-import random
-import struct
+import os
import shutil
-import unishark
+import tempfile
+import time
+import unittest
from functools import reduce
+from itertools import chain
-if sys.version_info[:2] <= (2, 6):
- try:
- import unittest2 as unittest
- except ImportError:
- sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
- sys.exit(1)
-else:
- import unittest
-
-if sys.version >= "3":
- long = int
-
-from pyspark.context import SparkConf, SparkContext, RDD
-from pyspark.storagelevel import StorageLevel
-from pyspark.streaming.context import StreamingContext
-from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream
-from pyspark.streaming.listener import StreamingListener
-
-
-class PySparkStreamingTestCase(unittest.TestCase):
-
- timeout = 30 # seconds
- duration = .5
-
- @classmethod
- def setUpClass(cls):
- class_name = cls.__name__
- conf = SparkConf().set("spark.default.parallelism", 1)
- cls.sc = SparkContext(appName=class_name, conf=conf)
- cls.sc.setCheckpointDir(tempfile.mkdtemp())
-
- @classmethod
- def tearDownClass(cls):
- cls.sc.stop()
- # Clean up in the JVM just in case there has been some issues in Python API
- try:
- jSparkContextOption = SparkContext._jvm.SparkContext.get()
- if jSparkContextOption.nonEmpty():
- jSparkContextOption.get().stop()
- except:
- pass
-
- def setUp(self):
- self.ssc = StreamingContext(self.sc, self.duration)
-
- def tearDown(self):
- if self.ssc is not None:
- self.ssc.stop(False)
- # Clean up in the JVM just in case there has been some issues in Python API
- try:
- jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive()
- if jStreamingContextOption.nonEmpty():
- jStreamingContextOption.get().stop(False)
- except:
- pass
-
- def wait_for(self, result, n):
- start_time = time.time()
- while len(result) < n and time.time() - start_time < self.timeout:
- time.sleep(0.01)
- if len(result) < n:
- print("timeout after", self.timeout)
-
- def _take(self, dstream, n):
- """
- Return the first `n` elements in the stream (will start and stop).
- """
- results = []
-
- def take(_, rdd):
- if rdd and len(results) < n:
- results.extend(rdd.take(n - len(results)))
-
- dstream.foreachRDD(take)
-
- self.ssc.start()
- self.wait_for(results, n)
- return results
-
- def _collect(self, dstream, n, block=True):
- """
- Collect each RDDs into the returned list.
-
- :return: list, which will have the collected items.
- """
- result = []
-
- def get_output(_, rdd):
- if rdd and len(result) < n:
- r = rdd.collect()
- if r:
- result.append(r)
-
- dstream.foreachRDD(get_output)
-
- if not block:
- return result
-
- self.ssc.start()
- self.wait_for(result, n)
- return result
-
- def _test_func(self, input, func, expected, sort=False, input2=None):
- """
- @param input: dataset for the test. This should be list of lists.
- @param func: wrapped function. This function should return PythonDStream object.
- @param expected: expected output for this testcase.
- """
- if not isinstance(input[0], RDD):
- input = [self.sc.parallelize(d, 1) for d in input]
- input_stream = self.ssc.queueStream(input)
- if input2 and not isinstance(input2[0], RDD):
- input2 = [self.sc.parallelize(d, 1) for d in input2]
- input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
-
- # Apply test function to stream.
- if input2:
- stream = func(input_stream, input_stream2)
- else:
- stream = func(input_stream)
-
- result = self._collect(stream, len(expected))
- if sort:
- self._sort_result_based_on_key(result)
- self._sort_result_based_on_key(expected)
- self.assertEqual(expected, result)
-
- def _sort_result_based_on_key(self, outputs):
- """Sort the list based on first value."""
- for output in outputs:
- output.sort(key=lambda x: x[0])
+from pyspark import SparkConf, SparkContext, RDD
+from pyspark.streaming import StreamingContext
+from pyspark.testing.streamingutils import PySparkStreamingTestCase
class BasicOperationTests(PySparkStreamingTestCase):
@@ -522,135 +389,6 @@ def failed_func(i):
self.fail("a failed func should throw an error")
-class StreamingListenerTests(PySparkStreamingTestCase):
-
- duration = .5
-
- class BatchInfoCollector(StreamingListener):
-
- def __init__(self):
- super(StreamingListener, self).__init__()
- self.batchInfosCompleted = []
- self.batchInfosStarted = []
- self.batchInfosSubmitted = []
- self.streamingStartedTime = []
-
- def onStreamingStarted(self, streamingStarted):
- self.streamingStartedTime.append(streamingStarted.time)
-
- def onBatchSubmitted(self, batchSubmitted):
- self.batchInfosSubmitted.append(batchSubmitted.batchInfo())
-
- def onBatchStarted(self, batchStarted):
- self.batchInfosStarted.append(batchStarted.batchInfo())
-
- def onBatchCompleted(self, batchCompleted):
- self.batchInfosCompleted.append(batchCompleted.batchInfo())
-
- def test_batch_info_reports(self):
- batch_collector = self.BatchInfoCollector()
- self.ssc.addStreamingListener(batch_collector)
- input = [[1], [2], [3], [4]]
-
- def func(dstream):
- return dstream.map(int)
- expected = [[1], [2], [3], [4]]
- self._test_func(input, func, expected)
-
- batchInfosSubmitted = batch_collector.batchInfosSubmitted
- batchInfosStarted = batch_collector.batchInfosStarted
- batchInfosCompleted = batch_collector.batchInfosCompleted
- streamingStartedTime = batch_collector.streamingStartedTime
-
- self.wait_for(batchInfosCompleted, 4)
-
- self.assertEqual(len(streamingStartedTime), 1)
-
- self.assertGreaterEqual(len(batchInfosSubmitted), 4)
- for info in batchInfosSubmitted:
- self.assertGreaterEqual(info.batchTime().milliseconds(), 0)
- self.assertGreaterEqual(info.submissionTime(), 0)
-
- for streamId in info.streamIdToInputInfo():
- streamInputInfo = info.streamIdToInputInfo()[streamId]
- self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0)
- self.assertGreaterEqual(streamInputInfo.numRecords, 0)
- for key in streamInputInfo.metadata():
- self.assertIsNotNone(streamInputInfo.metadata()[key])
- self.assertIsNotNone(streamInputInfo.metadataDescription())
-
- for outputOpId in info.outputOperationInfos():
- outputInfo = info.outputOperationInfos()[outputOpId]
- self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0)
- self.assertGreaterEqual(outputInfo.id(), 0)
- self.assertIsNotNone(outputInfo.name())
- self.assertIsNotNone(outputInfo.description())
- self.assertGreaterEqual(outputInfo.startTime(), -1)
- self.assertGreaterEqual(outputInfo.endTime(), -1)
- self.assertIsNone(outputInfo.failureReason())
-
- self.assertEqual(info.schedulingDelay(), -1)
- self.assertEqual(info.processingDelay(), -1)
- self.assertEqual(info.totalDelay(), -1)
- self.assertEqual(info.numRecords(), 0)
-
- self.assertGreaterEqual(len(batchInfosStarted), 4)
- for info in batchInfosStarted:
- self.assertGreaterEqual(info.batchTime().milliseconds(), 0)
- self.assertGreaterEqual(info.submissionTime(), 0)
-
- for streamId in info.streamIdToInputInfo():
- streamInputInfo = info.streamIdToInputInfo()[streamId]
- self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0)
- self.assertGreaterEqual(streamInputInfo.numRecords, 0)
- for key in streamInputInfo.metadata():
- self.assertIsNotNone(streamInputInfo.metadata()[key])
- self.assertIsNotNone(streamInputInfo.metadataDescription())
-
- for outputOpId in info.outputOperationInfos():
- outputInfo = info.outputOperationInfos()[outputOpId]
- self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0)
- self.assertGreaterEqual(outputInfo.id(), 0)
- self.assertIsNotNone(outputInfo.name())
- self.assertIsNotNone(outputInfo.description())
- self.assertGreaterEqual(outputInfo.startTime(), -1)
- self.assertGreaterEqual(outputInfo.endTime(), -1)
- self.assertIsNone(outputInfo.failureReason())
-
- self.assertGreaterEqual(info.schedulingDelay(), 0)
- self.assertEqual(info.processingDelay(), -1)
- self.assertEqual(info.totalDelay(), -1)
- self.assertEqual(info.numRecords(), 0)
-
- self.assertGreaterEqual(len(batchInfosCompleted), 4)
- for info in batchInfosCompleted:
- self.assertGreaterEqual(info.batchTime().milliseconds(), 0)
- self.assertGreaterEqual(info.submissionTime(), 0)
-
- for streamId in info.streamIdToInputInfo():
- streamInputInfo = info.streamIdToInputInfo()[streamId]
- self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0)
- self.assertGreaterEqual(streamInputInfo.numRecords, 0)
- for key in streamInputInfo.metadata():
- self.assertIsNotNone(streamInputInfo.metadata()[key])
- self.assertIsNotNone(streamInputInfo.metadataDescription())
-
- for outputOpId in info.outputOperationInfos():
- outputInfo = info.outputOperationInfos()[outputOpId]
- self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0)
- self.assertGreaterEqual(outputInfo.id(), 0)
- self.assertIsNotNone(outputInfo.name())
- self.assertIsNotNone(outputInfo.description())
- self.assertGreaterEqual(outputInfo.startTime(), 0)
- self.assertGreaterEqual(outputInfo.endTime(), 0)
- self.assertIsNone(outputInfo.failureReason())
-
- self.assertGreaterEqual(info.schedulingDelay(), 0)
- self.assertGreaterEqual(info.processingDelay(), 0)
- self.assertGreaterEqual(info.totalDelay(), 0)
- self.assertEqual(info.numRecords(), 0)
-
-
class WindowFunctionTests(PySparkStreamingTestCase):
timeout = 15
@@ -728,156 +466,6 @@ def func(dstream):
self._test_func(input, func, expected)
-class StreamingContextTests(PySparkStreamingTestCase):
-
- duration = 0.1
- setupCalled = False
-
- def _add_input_stream(self):
- inputs = [range(1, x) for x in range(101)]
- stream = self.ssc.queueStream(inputs)
- self._collect(stream, 1, block=False)
-
- def test_stop_only_streaming_context(self):
- self._add_input_stream()
- self.ssc.start()
- self.ssc.stop(False)
- self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)
-
- def test_stop_multiple_times(self):
- self._add_input_stream()
- self.ssc.start()
- self.ssc.stop(False)
- self.ssc.stop(False)
-
- def test_queue_stream(self):
- input = [list(range(i + 1)) for i in range(3)]
- dstream = self.ssc.queueStream(input)
- result = self._collect(dstream, 3)
- self.assertEqual(input, result)
-
- def test_text_file_stream(self):
- d = tempfile.mkdtemp()
- self.ssc = StreamingContext(self.sc, self.duration)
- dstream2 = self.ssc.textFileStream(d).map(int)
- result = self._collect(dstream2, 2, block=False)
- self.ssc.start()
- for name in ('a', 'b'):
- time.sleep(1)
- with open(os.path.join(d, name), "w") as f:
- f.writelines(["%d\n" % i for i in range(10)])
- self.wait_for(result, 2)
- self.assertEqual([list(range(10)), list(range(10))], result)
-
- def test_binary_records_stream(self):
- d = tempfile.mkdtemp()
- self.ssc = StreamingContext(self.sc, self.duration)
- dstream = self.ssc.binaryRecordsStream(d, 10).map(
- lambda v: struct.unpack("10b", bytes(v)))
- result = self._collect(dstream, 2, block=False)
- self.ssc.start()
- for name in ('a', 'b'):
- time.sleep(1)
- with open(os.path.join(d, name), "wb") as f:
- f.write(bytearray(range(10)))
- self.wait_for(result, 2)
- self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result])
-
- def test_union(self):
- input = [list(range(i + 1)) for i in range(3)]
- dstream = self.ssc.queueStream(input)
- dstream2 = self.ssc.queueStream(input)
- dstream3 = self.ssc.union(dstream, dstream2)
- result = self._collect(dstream3, 3)
- expected = [i * 2 for i in input]
- self.assertEqual(expected, result)
-
- def test_transform(self):
- dstream1 = self.ssc.queueStream([[1]])
- dstream2 = self.ssc.queueStream([[2]])
- dstream3 = self.ssc.queueStream([[3]])
-
- def func(rdds):
- rdd1, rdd2, rdd3 = rdds
- return rdd2.union(rdd3).union(rdd1)
-
- dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)
-
- self.assertEqual([2, 3, 1], self._take(dstream, 3))
-
- def test_transform_pairrdd(self):
- # This regression test case is for SPARK-17756.
- dstream = self.ssc.queueStream(
- [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd))
- self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3))
-
- def test_get_active(self):
- self.assertEqual(StreamingContext.getActive(), None)
-
- # Verify that getActive() returns the active context
- self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
- self.ssc.start()
- self.assertEqual(StreamingContext.getActive(), self.ssc)
-
- # Verify that getActive() returns None
- self.ssc.stop(False)
- self.assertEqual(StreamingContext.getActive(), None)
-
- # Verify that if the Java context is stopped, then getActive() returns None
- self.ssc = StreamingContext(self.sc, self.duration)
- self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
- self.ssc.start()
- self.assertEqual(StreamingContext.getActive(), self.ssc)
- self.ssc._jssc.stop(False)
- self.assertEqual(StreamingContext.getActive(), None)
-
- def test_get_active_or_create(self):
- # Test StreamingContext.getActiveOrCreate() without checkpoint data
- # See CheckpointTests for tests with checkpoint data
- self.ssc = None
- self.assertEqual(StreamingContext.getActive(), None)
-
- def setupFunc():
- ssc = StreamingContext(self.sc, self.duration)
- ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
- self.setupCalled = True
- return ssc
-
- # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active
- self.setupCalled = False
- self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
- self.assertTrue(self.setupCalled)
-
- # Verify that getActiveOrCreate() returns active context and does not call the setupFunc
- self.ssc.start()
- self.setupCalled = False
- self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc)
- self.assertFalse(self.setupCalled)
-
- # Verify that getActiveOrCreate() calls setupFunc after active context is stopped
- self.ssc.stop(False)
- self.setupCalled = False
- self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
- self.assertTrue(self.setupCalled)
-
- # Verify that if the Java context is stopped, then getActive() returns None
- self.ssc = StreamingContext(self.sc, self.duration)
- self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
- self.ssc.start()
- self.assertEqual(StreamingContext.getActive(), self.ssc)
- self.ssc._jssc.stop(False)
- self.setupCalled = False
- self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
- self.assertTrue(self.setupCalled)
-
- def test_await_termination_or_timeout(self):
- self._add_input_stream()
- self.ssc.start()
- self.assertFalse(self.ssc.awaitTerminationOrTimeout(0.001))
- self.ssc.stop(False)
- self.assertTrue(self.ssc.awaitTerminationOrTimeout(0.001))
-
-
class CheckpointTests(unittest.TestCase):
setupCalled = False
@@ -1042,140 +630,11 @@ def check_output(n):
self.ssc.stop(True, True)
-class KinesisStreamTests(PySparkStreamingTestCase):
-
- def test_kinesis_stream_api(self):
- # Don't start the StreamingContext because we cannot test it in Jenkins
- kinesisStream1 = KinesisUtils.createStream(
- self.ssc, "myAppNam", "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
- InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2)
- kinesisStream2 = KinesisUtils.createStream(
- self.ssc, "myAppNam", "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
- InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2,
- "awsAccessKey", "awsSecretKey")
-
- def test_kinesis_stream(self):
- if not are_kinesis_tests_enabled:
- sys.stderr.write(
- "Skipped test_kinesis_stream (enable by setting environment variable %s=1"
- % kinesis_test_environ_var)
- return
-
- import random
- kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000)))
- kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2)
- try:
- kinesisTestUtils.createStream()
- aWSCredentials = kinesisTestUtils.getAWSCredentials()
- stream = KinesisUtils.createStream(
- self.ssc, kinesisAppName, kinesisTestUtils.streamName(),
- kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(),
- InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY,
- aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey())
-
- outputBuffer = []
-
- def get_output(_, rdd):
- for e in rdd.collect():
- outputBuffer.append(e)
-
- stream.foreachRDD(get_output)
- self.ssc.start()
-
- testData = [i for i in range(1, 11)]
- expectedOutput = set([str(i) for i in testData])
- start_time = time.time()
- while time.time() - start_time < 120:
- kinesisTestUtils.pushData(testData)
- if expectedOutput == set(outputBuffer):
- break
- time.sleep(10)
- self.assertEqual(expectedOutput, set(outputBuffer))
- except:
- import traceback
- traceback.print_exc()
- raise
- finally:
- self.ssc.stop(False)
- kinesisTestUtils.deleteStream()
- kinesisTestUtils.deleteDynamoDBTable(kinesisAppName)
-
-
-# Search jar in the project dir using the jar name_prefix for both sbt build and maven build because
-# the artifact jars are in different directories.
-def search_jar(dir, name_prefix):
- # We should ignore the following jars
- ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")
- jars = (glob.glob(os.path.join(dir, "target/scala-*/" + name_prefix + "-*.jar")) + # sbt build
- glob.glob(os.path.join(dir, "target/" + name_prefix + "_*.jar"))) # maven build
- return [jar for jar in jars if not jar.endswith(ignored_jar_suffixes)]
-
-
-def _kinesis_asl_assembly_dir():
- SPARK_HOME = os.environ["SPARK_HOME"]
- return os.path.join(SPARK_HOME, "external/kinesis-asl-assembly")
-
-
-def search_kinesis_asl_assembly_jar():
- jars = search_jar(_kinesis_asl_assembly_dir(), "spark-streaming-kinesis-asl-assembly")
- if not jars:
- return None
- elif len(jars) > 1:
- raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please "
- "remove all but one") % (", ".join(jars)))
- else:
- return jars[0]
-
-
-# Must be same as the variable and condition defined in KinesisTestUtils.scala and modules.py
-kinesis_test_environ_var = "ENABLE_KINESIS_TESTS"
-are_kinesis_tests_enabled = os.environ.get(kinesis_test_environ_var) == '1'
-
if __name__ == "__main__":
- from pyspark.streaming.tests import *
- kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar()
-
- if kinesis_asl_assembly_jar is None:
- kinesis_jar_present = False
- jars_args = ""
- else:
- kinesis_jar_present = True
- jars_args = "--jars %s" % kinesis_asl_assembly_jar
-
- existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
- os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
- testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests,
- StreamingListenerTests]
-
- if kinesis_jar_present is True:
- testcases.append(KinesisStreamTests)
- elif are_kinesis_tests_enabled is False:
- sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was "
- "not compiled into a JAR. To run these tests, "
- "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package "
- "streaming-kinesis-asl-assembly/assembly' or "
- "'build/mvn -Pkinesis-asl package' before running this test.")
- else:
- raise Exception(
- ("Failed to find Spark Streaming Kinesis assembly jar in %s. "
- % _kinesis_asl_assembly_dir()) +
- "You need to build Spark with 'build/sbt -Pkinesis-asl "
- "assembly/package streaming-kinesis-asl-assembly/assembly'"
- "or 'build/mvn -Pkinesis-asl package' before running this test.")
-
- sys.stderr.write("Running tests: %s \n" % (str(testcases)))
- failed = False
- for testcase in testcases:
- sys.stderr.write("[Running %s]\n" % (testcase))
- tests = unittest.TestLoader().loadTestsFromTestCase(testcase)
- runner = unishark.BufferedTestRunner(
- verbosity=2,
- reporters=[unishark.XUnitReporter('target/test-reports/pyspark.streaming_{}'.format(
- os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))])
-
- result = runner.run(tests)
- if not result.wasSuccessful():
- failed = True
- sys.exit(failed)
+ from pyspark.streaming.tests.test_dstream import *
+
+ try:
+ import xmlrunner
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+ except ImportError:
+ unittest.main(verbosity=2)
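A minimal sketch of the queueStream-based testing pattern that the removed helpers (now provided by pyspark.testing.streamingutils) are built on; names and the one-second batch interval are illustrative:

from pyspark import SparkConf, SparkContext
from pyspark.streaming import StreamingContext

conf = SparkConf().set("spark.default.parallelism", 1)
sc = SparkContext(appName="dstream-sketch", conf=conf)
ssc = StreamingContext(sc, 1)

collected = []
input_rdds = [sc.parallelize(batch, 1) for batch in [[1, 2], [3]]]
ssc.queueStream(input_rdds).map(lambda x: x * 10).foreachRDD(
    lambda rdd: collected.append(rdd.collect()))

ssc.start()
# The real helpers poll (wait_for) until the expected number of batches has been
# collected, then stop the StreamingContext without stopping the SparkContext.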
diff --git a/python/pyspark/streaming/tests/test_kinesis.py b/python/pyspark/streaming/tests/test_kinesis.py
new file mode 100644
index 0000000000000..d8a0b47f04097
--- /dev/null
+++ b/python/pyspark/streaming/tests/test_kinesis.py
@@ -0,0 +1,89 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import time
+import unittest
+
+from pyspark import StorageLevel
+from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream
+from pyspark.testing.streamingutils import should_test_kinesis, kinesis_requirement_message, \
+ PySparkStreamingTestCase
+
+
+@unittest.skipIf(not should_test_kinesis, kinesis_requirement_message)
+class KinesisStreamTests(PySparkStreamingTestCase):
+
+ def test_kinesis_stream_api(self):
+ # Don't start the StreamingContext because we cannot test it in Jenkins
+ KinesisUtils.createStream(
+ self.ssc, "myAppNam", "mySparkStream",
+ "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
+ InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2)
+ KinesisUtils.createStream(
+ self.ssc, "myAppNam", "mySparkStream",
+ "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
+ InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2,
+ "awsAccessKey", "awsSecretKey")
+
+ def test_kinesis_stream(self):
+ import random
+ kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000)))
+ kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2)
+ try:
+ kinesisTestUtils.createStream()
+ aWSCredentials = kinesisTestUtils.getAWSCredentials()
+ stream = KinesisUtils.createStream(
+ self.ssc, kinesisAppName, kinesisTestUtils.streamName(),
+ kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(),
+ InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY,
+ aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey())
+
+ outputBuffer = []
+
+ def get_output(_, rdd):
+ for e in rdd.collect():
+ outputBuffer.append(e)
+
+ stream.foreachRDD(get_output)
+ self.ssc.start()
+
+ testData = [i for i in range(1, 11)]
+ expectedOutput = set([str(i) for i in testData])
+ start_time = time.time()
+ while time.time() - start_time < 120:
+ kinesisTestUtils.pushData(testData)
+ if expectedOutput == set(outputBuffer):
+ break
+ time.sleep(10)
+ self.assertEqual(expectedOutput, set(outputBuffer))
+ except:
+ import traceback
+ traceback.print_exc()
+ raise
+ finally:
+ self.ssc.stop(False)
+ kinesisTestUtils.deleteStream()
+ kinesisTestUtils.deleteDynamoDBTable(kinesisAppName)
+
+
+if __name__ == "__main__":
+ from pyspark.streaming.tests.test_kinesis import *
+
+ try:
+ import xmlrunner
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+ except ImportError:
+ unittest.main(verbosity=2)
diff --git a/python/pyspark/streaming/tests/test_listener.py b/python/pyspark/streaming/tests/test_listener.py
new file mode 100644
index 0000000000000..7c874b6b32500
--- /dev/null
+++ b/python/pyspark/streaming/tests/test_listener.py
@@ -0,0 +1,158 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.streaming import StreamingListener
+from pyspark.testing.streamingutils import PySparkStreamingTestCase
+
+
+class StreamingListenerTests(PySparkStreamingTestCase):
+
+ duration = .5
+
+ class BatchInfoCollector(StreamingListener):
+
+ def __init__(self):
+ super(StreamingListener, self).__init__()
+ self.batchInfosCompleted = []
+ self.batchInfosStarted = []
+ self.batchInfosSubmitted = []
+ self.streamingStartedTime = []
+
+ def onStreamingStarted(self, streamingStarted):
+ self.streamingStartedTime.append(streamingStarted.time)
+
+ def onBatchSubmitted(self, batchSubmitted):
+ self.batchInfosSubmitted.append(batchSubmitted.batchInfo())
+
+ def onBatchStarted(self, batchStarted):
+ self.batchInfosStarted.append(batchStarted.batchInfo())
+
+ def onBatchCompleted(self, batchCompleted):
+ self.batchInfosCompleted.append(batchCompleted.batchInfo())
+
+ def test_batch_info_reports(self):
+ batch_collector = self.BatchInfoCollector()
+ self.ssc.addStreamingListener(batch_collector)
+ input = [[1], [2], [3], [4]]
+
+ def func(dstream):
+ return dstream.map(int)
+ expected = [[1], [2], [3], [4]]
+ self._test_func(input, func, expected)
+
+ batchInfosSubmitted = batch_collector.batchInfosSubmitted
+ batchInfosStarted = batch_collector.batchInfosStarted
+ batchInfosCompleted = batch_collector.batchInfosCompleted
+ streamingStartedTime = batch_collector.streamingStartedTime
+
+ self.wait_for(batchInfosCompleted, 4)
+
+ self.assertEqual(len(streamingStartedTime), 1)
+
+ self.assertGreaterEqual(len(batchInfosSubmitted), 4)
+ for info in batchInfosSubmitted:
+ self.assertGreaterEqual(info.batchTime().milliseconds(), 0)
+ self.assertGreaterEqual(info.submissionTime(), 0)
+
+ for streamId in info.streamIdToInputInfo():
+ streamInputInfo = info.streamIdToInputInfo()[streamId]
+ self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0)
+ self.assertGreaterEqual(streamInputInfo.numRecords, 0)
+ for key in streamInputInfo.metadata():
+ self.assertIsNotNone(streamInputInfo.metadata()[key])
+ self.assertIsNotNone(streamInputInfo.metadataDescription())
+
+ for outputOpId in info.outputOperationInfos():
+ outputInfo = info.outputOperationInfos()[outputOpId]
+ self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0)
+ self.assertGreaterEqual(outputInfo.id(), 0)
+ self.assertIsNotNone(outputInfo.name())
+ self.assertIsNotNone(outputInfo.description())
+ self.assertGreaterEqual(outputInfo.startTime(), -1)
+ self.assertGreaterEqual(outputInfo.endTime(), -1)
+ self.assertIsNone(outputInfo.failureReason())
+
+ self.assertEqual(info.schedulingDelay(), -1)
+ self.assertEqual(info.processingDelay(), -1)
+ self.assertEqual(info.totalDelay(), -1)
+ self.assertEqual(info.numRecords(), 0)
+
+ self.assertGreaterEqual(len(batchInfosStarted), 4)
+ for info in batchInfosStarted:
+ self.assertGreaterEqual(info.batchTime().milliseconds(), 0)
+ self.assertGreaterEqual(info.submissionTime(), 0)
+
+ for streamId in info.streamIdToInputInfo():
+ streamInputInfo = info.streamIdToInputInfo()[streamId]
+ self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0)
+ self.assertGreaterEqual(streamInputInfo.numRecords, 0)
+ for key in streamInputInfo.metadata():
+ self.assertIsNotNone(streamInputInfo.metadata()[key])
+ self.assertIsNotNone(streamInputInfo.metadataDescription())
+
+ for outputOpId in info.outputOperationInfos():
+ outputInfo = info.outputOperationInfos()[outputOpId]
+ self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0)
+ self.assertGreaterEqual(outputInfo.id(), 0)
+ self.assertIsNotNone(outputInfo.name())
+ self.assertIsNotNone(outputInfo.description())
+ self.assertGreaterEqual(outputInfo.startTime(), -1)
+ self.assertGreaterEqual(outputInfo.endTime(), -1)
+ self.assertIsNone(outputInfo.failureReason())
+
+ self.assertGreaterEqual(info.schedulingDelay(), 0)
+ self.assertEqual(info.processingDelay(), -1)
+ self.assertEqual(info.totalDelay(), -1)
+ self.assertEqual(info.numRecords(), 0)
+
+ self.assertGreaterEqual(len(batchInfosCompleted), 4)
+ for info in batchInfosCompleted:
+ self.assertGreaterEqual(info.batchTime().milliseconds(), 0)
+ self.assertGreaterEqual(info.submissionTime(), 0)
+
+ for streamId in info.streamIdToInputInfo():
+ streamInputInfo = info.streamIdToInputInfo()[streamId]
+ self.assertGreaterEqual(streamInputInfo.inputStreamId(), 0)
+ self.assertGreaterEqual(streamInputInfo.numRecords, 0)
+ for key in streamInputInfo.metadata():
+ self.assertIsNotNone(streamInputInfo.metadata()[key])
+ self.assertIsNotNone(streamInputInfo.metadataDescription())
+
+ for outputOpId in info.outputOperationInfos():
+ outputInfo = info.outputOperationInfos()[outputOpId]
+ self.assertGreaterEqual(outputInfo.batchTime().milliseconds(), 0)
+ self.assertGreaterEqual(outputInfo.id(), 0)
+ self.assertIsNotNone(outputInfo.name())
+ self.assertIsNotNone(outputInfo.description())
+ self.assertGreaterEqual(outputInfo.startTime(), 0)
+ self.assertGreaterEqual(outputInfo.endTime(), 0)
+ self.assertIsNone(outputInfo.failureReason())
+
+ self.assertGreaterEqual(info.schedulingDelay(), 0)
+ self.assertGreaterEqual(info.processingDelay(), 0)
+ self.assertGreaterEqual(info.totalDelay(), 0)
+ self.assertEqual(info.numRecords(), 0)
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.streaming.tests.test_listener import *
+
+ try:
+ import xmlrunner
+ unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
+ except ImportError:
+ unittest.main(verbosity=2)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index b61643eb0a16e..98b505c9046be 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -147,8 +147,8 @@ def __init__(self):
@classmethod
def _getOrCreate(cls):
"""Internal function to get or create global BarrierTaskContext."""
- if cls._taskContext is None:
- cls._taskContext = BarrierTaskContext()
+ if not isinstance(cls._taskContext, BarrierTaskContext):
+ cls._taskContext = object.__new__(cls)
return cls._taskContext
@classmethod
diff --git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py
deleted file mode 100644
index 5b43729f9ebb1..0000000000000
--- a/python/pyspark/test_serializers.py
+++ /dev/null
@@ -1,90 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import io
-import math
-import struct
-import sys
-import unittest
-
-try:
- import xmlrunner
-except ImportError:
- xmlrunner = None
-
-from pyspark import serializers
-
-
-def read_int(b):
- return struct.unpack("!i", b)[0]
-
-
-def write_int(i):
- return struct.pack("!i", i)
-
-
-class SerializersTest(unittest.TestCase):
-
- def test_chunked_stream(self):
- original_bytes = bytearray(range(100))
- for data_length in [1, 10, 100]:
- for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
- dest = ByteArrayOutput()
- stream_out = serializers.ChunkedStream(dest, buffer_length)
- stream_out.write(original_bytes[:data_length])
- stream_out.close()
- num_chunks = int(math.ceil(float(data_length) / buffer_length))
- # length for each chunk, and a final -1 at the very end
- exp_size = (num_chunks + 1) * 4 + data_length
- self.assertEqual(len(dest.buffer), exp_size)
- dest_pos = 0
- data_pos = 0
- for chunk_idx in range(num_chunks):
- chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
- if chunk_idx == num_chunks - 1:
- exp_length = data_length % buffer_length
- if exp_length == 0:
- exp_length = buffer_length
- else:
- exp_length = buffer_length
- self.assertEqual(chunk_length, exp_length)
- dest_pos += 4
- dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
- orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
- self.assertEqual(dest_chunk, orig_chunk)
- dest_pos += chunk_length
- data_pos += chunk_length
- # ends with a -1
- self.assertEqual(dest.buffer[-4:], write_int(-1))
-
-
-class ByteArrayOutput(object):
- def __init__(self):
- self.buffer = bytearray()
-
- def write(self, b):
- self.buffer += b
-
- def close(self):
- pass
-
-if __name__ == '__main__':
- from pyspark.test_serializers import *
- if xmlrunner:
- unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
- else:
- unittest.main(verbosity=2)
diff --git a/python/pyspark/testing/__init__.py b/python/pyspark/testing/__init__.py
new file mode 100644
index 0000000000000..12bdf0d0175b6
--- /dev/null
+++ b/python/pyspark/testing/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/testing/mllibutils.py b/python/pyspark/testing/mllibutils.py
new file mode 100644
index 0000000000000..25f1bba8d37ac
--- /dev/null
+++ b/python/pyspark/testing/mllibutils.py
@@ -0,0 +1,35 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark import SparkContext
+from pyspark.serializers import PickleSerializer
+from pyspark.sql import SparkSession
+
+
+def make_serializer():
+ return PickleSerializer()
+
+
+class MLlibTestCase(unittest.TestCase):
+ def setUp(self):
+ self.sc = SparkContext('local[4]', "MLlib tests")
+ self.spark = SparkSession(self.sc)
+
+ def tearDown(self):
+ self.spark.stop()
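+
+
+# Usage sketch (hypothetical subclass): MLlib test modules can extend MLlibTestCase so each
+# test gets a fresh SparkContext and SparkSession, e.g.
+#   class VectorCountTests(MLlibTestCase):
+#       def test_count(self):
+#           self.assertEqual(self.sc.parallelize([1, 2, 3]).count(), 3)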
diff --git a/python/pyspark/testing/mlutils.py b/python/pyspark/testing/mlutils.py
new file mode 100644
index 0000000000000..12bf650a28ee1
--- /dev/null
+++ b/python/pyspark/testing/mlutils.py
@@ -0,0 +1,161 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+
+from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer
+from pyspark.ml.param import Param, Params, TypeConverters
+from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
+from pyspark.ml.wrapper import _java2py
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.types import DoubleType
+from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase
+
+
+def check_params(test_self, py_stage, check_params_exist=True):
+ """
+ Checks common requirements for Params.params:
+    - the set of params exists in both Java and Python and is ordered by name
+    - each param's parent has the same UID as the object's UID
+    - default param values from Java match the values in Python
+    - optionally, all params from Java also exist in Python
+ """
+ py_stage_str = "%s %s" % (type(py_stage), py_stage)
+ if not hasattr(py_stage, "_to_java"):
+ return
+ java_stage = py_stage._to_java()
+ if java_stage is None:
+ return
+ test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str)
+ if check_params_exist:
+ param_names = [p.name for p in py_stage.params]
+ java_params = list(java_stage.params())
+ java_param_names = [jp.name() for jp in java_params]
+ test_self.assertEqual(
+ param_names, sorted(java_param_names),
+ "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s"
+ % (py_stage_str, java_param_names, param_names))
+ for p in py_stage.params:
+ test_self.assertEqual(p.parent, py_stage.uid)
+ java_param = java_stage.getParam(p.name)
+ py_has_default = py_stage.hasDefault(p)
+ java_has_default = java_stage.hasDefault(java_param)
+ test_self.assertEqual(py_has_default, java_has_default,
+ "Default value mismatch of param %s for Params %s"
+ % (p.name, str(py_stage)))
+ if py_has_default:
+ if p.name == "seed":
+ continue # Random seeds between Spark and PySpark are different
+ java_default = _java2py(test_self.sc,
+ java_stage.clear(java_param).getOrDefault(java_param))
+ py_stage._clear(p)
+ py_default = py_stage.getOrDefault(p)
+ # equality test for NaN is always False
+ if isinstance(java_default, float) and np.isnan(java_default):
+ java_default = "NaN"
+ py_default = "NaN" if np.isnan(py_default) else "not NaN"
+ test_self.assertEqual(
+ java_default, py_default,
+ "Java default %s != python default %s of param %s for Params %s"
+ % (str(java_default), str(py_default), p.name, str(py_stage)))
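+
+# Illustrative usage sketch (the stage is an assumption): from a test that owns a SparkContext
+# as `self.sc`, compare a Python ML stage against its Java counterpart, e.g.
+#   from pyspark.ml.classification import LogisticRegression
+#   check_params(self, LogisticRegression(maxIter=1), check_params_exist=True)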
+
+
+class SparkSessionTestCase(PySparkTestCase):
+ @classmethod
+ def setUpClass(cls):
+ PySparkTestCase.setUpClass()
+ cls.spark = SparkSession(cls.sc)
+
+ @classmethod
+ def tearDownClass(cls):
+ PySparkTestCase.tearDownClass()
+ cls.spark.stop()
+
+
+class MockDataset(DataFrame):
+
+ def __init__(self):
+ self.index = 0
+
+
+class HasFake(Params):
+
+ def __init__(self):
+ super(HasFake, self).__init__()
+ self.fake = Param(self, "fake", "fake param")
+
+ def getFake(self):
+ return self.getOrDefault(self.fake)
+
+
+class MockTransformer(Transformer, HasFake):
+
+ def __init__(self):
+ super(MockTransformer, self).__init__()
+ self.dataset_index = None
+
+ def _transform(self, dataset):
+ self.dataset_index = dataset.index
+ dataset.index += 1
+ return dataset
+
+
+class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
+
+ shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
+ "data in a DataFrame",
+ typeConverter=TypeConverters.toFloat)
+
+ def __init__(self, shiftVal=1):
+ super(MockUnaryTransformer, self).__init__()
+ self._setDefault(shift=1)
+ self._set(shift=shiftVal)
+
+ def getShift(self):
+ return self.getOrDefault(self.shift)
+
+ def setShift(self, shift):
+ self._set(shift=shift)
+
+ def createTransformFunc(self):
+ shiftVal = self.getShift()
+ return lambda x: x + shiftVal
+
+ def outputDataType(self):
+ return DoubleType()
+
+ def validateInputType(self, inputType):
+ if inputType != DoubleType():
+ raise TypeError("Bad input type: {}. ".format(inputType) +
+ "Requires Double.")
+
+
+class MockEstimator(Estimator, HasFake):
+
+ def __init__(self):
+ super(MockEstimator, self).__init__()
+ self.dataset_index = None
+
+ def _fit(self, dataset):
+ self.dataset_index = dataset.index
+ model = MockModel()
+ self._copyValues(model)
+ return model
+
+
+class MockModel(MockTransformer, Model, HasFake):
+ pass
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
new file mode 100644
index 0000000000000..afc40ccf4139d
--- /dev/null
+++ b/python/pyspark/testing/sqlutils.py
@@ -0,0 +1,268 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import os
+import shutil
+import tempfile
+from contextlib import contextmanager
+
+from pyspark.sql import SparkSession
+from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
+from pyspark.testing.utils import ReusedPySparkTestCase
+from pyspark.util import _exception_message
+
+
+pandas_requirement_message = None
+try:
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+except ImportError as e:
+ # If Pandas version requirement is not satisfied, skip related tests.
+ pandas_requirement_message = _exception_message(e)
+
+pyarrow_requirement_message = None
+try:
+ from pyspark.sql.utils import require_minimum_pyarrow_version
+ require_minimum_pyarrow_version()
+except ImportError as e:
+ # If Arrow version requirement is not satisfied, skip related tests.
+ pyarrow_requirement_message = _exception_message(e)
+
+test_not_compiled_message = None
+try:
+ from pyspark.sql.utils import require_test_compiled
+ require_test_compiled()
+except Exception as e:
+ test_not_compiled_message = _exception_message(e)
+
+have_pandas = pandas_requirement_message is None
+have_pyarrow = pyarrow_requirement_message is None
+test_compiled = test_not_compiled_message is None
+
+
+class UTCOffsetTimezone(datetime.tzinfo):
+ """
+    Specifies a timezone as a fixed UTC offset, in hours.
+ """
+
+ def __init__(self, offset=0):
+ self.ZERO = datetime.timedelta(hours=offset)
+
+ def utcoffset(self, dt):
+ return self.ZERO
+
+ def dst(self, dt):
+ return self.ZERO
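+
+# Usage sketch: build timezone-aware datetimes for round-trip tests, e.g.
+#   datetime.datetime(2015, 11, 1, 0, 30, tzinfo=UTCOffsetTimezone(2))  # UTC+2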
+
+
+class ExamplePointUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for ExamplePoint.
+ """
+
+ @classmethod
+ def sqlType(self):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return 'pyspark.sql.tests'
+
+ @classmethod
+ def scalaUDT(cls):
+ return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+ """
+ An example class to demonstrate UDT in Scala, Java, and Python.
+ """
+
+ __UDT__ = ExamplePointUDT()
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __repr__(self):
+ return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+ def __str__(self):
+ return "(%s,%s)" % (self.x, self.y)
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__) and \
+ other.x == self.x and other.y == self.y
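+
+# Usage sketch (assumes an active SparkSession `spark`): ExamplePoint round-trips through
+# its UDT when placed in a DataFrame, e.g.
+#   df = spark.createDataFrame([Row(label=1.0, point=ExamplePoint(1.0, 2.0))])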
+
+
+class PythonOnlyUDT(UserDefinedType):
+ """
+    Python-only user-defined type (UDT) for PythonOnlyPoint.
+ """
+
+ @classmethod
+ def sqlType(self):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return '__main__'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return PythonOnlyPoint(datum[0], datum[1])
+
+ @staticmethod
+ def foo():
+ pass
+
+ @property
+ def props(self):
+ return {}
+
+
+class PythonOnlyPoint(ExamplePoint):
+ """
+    An example class to demonstrate a UDT in Python only.
+ """
+ __UDT__ = PythonOnlyUDT()
+
+
+class MyObject(object):
+ def __init__(self, key, value):
+ self.key = key
+ self.value = value
+
+
+class SQLTestUtils(object):
+ """
+    This util assumes the instance that uses it has a 'spark' attribute holding a SparkSession.
+    It is usually mixed into 'ReusedSQLTestCase', but it can be used by any class that is
+    guaranteed to provide a 'spark' attribute.
+ """
+
+ @contextmanager
+ def sql_conf(self, pairs):
+ """
+        A convenient context manager for testing configuration-specific logic. It sets each
+        configuration key in `pairs` to its value and restores the previous values on exit.
+ """
+ assert isinstance(pairs, dict), "pairs should be a dictionary."
+ assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+ keys = pairs.keys()
+ new_values = pairs.values()
+ old_values = [self.spark.conf.get(key, None) for key in keys]
+ for key, new_value in zip(keys, new_values):
+ self.spark.conf.set(key, new_value)
+ try:
+ yield
+ finally:
+ for key, old_value in zip(keys, old_values):
+ if old_value is None:
+ self.spark.conf.unset(key)
+ else:
+ self.spark.conf.set(key, old_value)
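+
+    # Usage sketch (configuration key chosen for illustration): temporarily set a SQL conf
+    # for a block of assertions, e.g.
+    #   with self.sql_conf({"spark.sql.shuffle.partitions": "4"}):
+    #       ...  # code that relies on the temporary setting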
+
+ @contextmanager
+ def database(self, *databases):
+ """
+        A convenient context manager to test with some specific databases. It drops the given
+        databases if they exist and resets the current database to "default" on exit.
+ """
+ assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+ try:
+ yield
+ finally:
+ for db in databases:
+ self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
+ self.spark.catalog.setCurrentDatabase("default")
+
+ @contextmanager
+ def table(self, *tables):
+ """
+        A convenient context manager to test with some specific tables. It drops the given
+        tables if they exist.
+ """
+ assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+ try:
+ yield
+ finally:
+ for t in tables:
+ self.spark.sql("DROP TABLE IF EXISTS %s" % t)
+
+ @contextmanager
+ def tempView(self, *views):
+ """
+        A convenient context manager to test with some specific views. It drops the given
+        views if they exist.
+ """
+ assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+ try:
+ yield
+ finally:
+ for v in views:
+ self.spark.catalog.dropTempView(v)
+
+ @contextmanager
+ def function(self, *functions):
+ """
+        A convenient context manager to test with some specific functions. It drops the given
+        functions if they exist.
+ """
+ assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+ try:
+ yield
+ finally:
+ for f in functions:
+ self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
+
+
+class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
+ @classmethod
+ def setUpClass(cls):
+ super(ReusedSQLTestCase, cls).setUpClass()
+ cls.spark = SparkSession(cls.sc)
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ cls.df = cls.spark.createDataFrame(cls.testData)
+
+ @classmethod
+ def tearDownClass(cls):
+ super(ReusedSQLTestCase, cls).tearDownClass()
+ cls.spark.stop()
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+ def assertPandasEqual(self, expected, result):
+ msg = ("DataFrames are not equal: " +
+ "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
+ "\n\nResult:\n%s\n%s" % (result, result.dtypes))
+ self.assertTrue(expected.equals(result), msg=msg)
diff --git a/python/pyspark/testing/streamingutils.py b/python/pyspark/testing/streamingutils.py
new file mode 100644
index 0000000000000..85a2fa14b936c
--- /dev/null
+++ b/python/pyspark/testing/streamingutils.py
@@ -0,0 +1,190 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import glob
+import os
+import tempfile
+import time
+import unittest
+
+from pyspark import SparkConf, SparkContext, RDD
+from pyspark.streaming import StreamingContext
+
+
+def search_kinesis_asl_assembly_jar():
+ kinesis_asl_assembly_dir = os.path.join(
+ os.environ["SPARK_HOME"], "external/kinesis-asl-assembly")
+
+ # We should ignore the following jars
+ ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")
+
+    # Search for the jar in the project directory using name_prefix, covering both the sbt
+    # and maven builds because their artifact jars land in different directories.
+ name_prefix = "spark-streaming-kinesis-asl-assembly"
+ sbt_build = glob.glob(os.path.join(
+ kinesis_asl_assembly_dir, "target/scala-*/%s-*.jar" % name_prefix))
+ maven_build = glob.glob(os.path.join(
+ kinesis_asl_assembly_dir, "target/%s_*.jar" % name_prefix))
+ jar_paths = sbt_build + maven_build
+ jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)]
+
+ if not jars:
+ return None
+ elif len(jars) > 1:
+ raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please "
+ "remove all but one") % (", ".join(jars)))
+ else:
+ return jars[0]
+
+
+# Must be the same as the variable and condition defined in KinesisTestUtils.scala and modules.py
+kinesis_test_environ_var = "ENABLE_KINESIS_TESTS"
+should_skip_kinesis_tests = not os.environ.get(kinesis_test_environ_var) == '1'
+
+if should_skip_kinesis_tests:
+ kinesis_requirement_message = (
+ "Skipping all Kinesis Python tests as environmental variable 'ENABLE_KINESIS_TESTS' "
+ "was not set.")
+else:
+ kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar()
+ if kinesis_asl_assembly_jar is None:
+ kinesis_requirement_message = (
+ "Skipping all Kinesis Python tests as the optional Kinesis project was "
+ "not compiled into a JAR. To run these tests, "
+ "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package "
+ "streaming-kinesis-asl-assembly/assembly' or "
+ "'build/mvn -Pkinesis-asl package' before running this test.")
+ else:
+ existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+ jars_args = "--jars %s" % kinesis_asl_assembly_jar
+ os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
+ kinesis_requirement_message = None
+
+should_test_kinesis = kinesis_requirement_message is None
+
+
+class PySparkStreamingTestCase(unittest.TestCase):
+
+ timeout = 30 # seconds
+ duration = .5
+
+ @classmethod
+ def setUpClass(cls):
+ class_name = cls.__name__
+ conf = SparkConf().set("spark.default.parallelism", 1)
+ cls.sc = SparkContext(appName=class_name, conf=conf)
+ cls.sc.setCheckpointDir(tempfile.mkdtemp())
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.sc.stop()
+        # Clean up in the JVM just in case there have been issues in the Python API
+ try:
+ jSparkContextOption = SparkContext._jvm.SparkContext.get()
+ if jSparkContextOption.nonEmpty():
+ jSparkContextOption.get().stop()
+ except:
+ pass
+
+ def setUp(self):
+ self.ssc = StreamingContext(self.sc, self.duration)
+
+ def tearDown(self):
+ if self.ssc is not None:
+ self.ssc.stop(False)
+        # Clean up in the JVM just in case there have been issues in the Python API
+ try:
+ jStreamingContextOption = StreamingContext._jvm.SparkContext.getActive()
+ if jStreamingContextOption.nonEmpty():
+ jStreamingContextOption.get().stop(False)
+ except:
+ pass
+
+ def wait_for(self, result, n):
+ start_time = time.time()
+ while len(result) < n and time.time() - start_time < self.timeout:
+ time.sleep(0.01)
+ if len(result) < n:
+ print("timeout after", self.timeout)
+
+ def _take(self, dstream, n):
+ """
+ Return the first `n` elements in the stream (will start and stop).
+ """
+ results = []
+
+ def take(_, rdd):
+ if rdd and len(results) < n:
+ results.extend(rdd.take(n - len(results)))
+
+ dstream.foreachRDD(take)
+
+ self.ssc.start()
+ self.wait_for(results, n)
+ return results
+
+ def _collect(self, dstream, n, block=True):
+ """
+        Collect each RDD into the returned list.
+
+ :return: list, which will have the collected items.
+ """
+ result = []
+
+ def get_output(_, rdd):
+ if rdd and len(result) < n:
+ r = rdd.collect()
+ if r:
+ result.append(r)
+
+ dstream.foreachRDD(get_output)
+
+ if not block:
+ return result
+
+ self.ssc.start()
+ self.wait_for(result, n)
+ return result
+
+ def _test_func(self, input, func, expected, sort=False, input2=None):
+ """
+        @param input: dataset for the test. This should be a list of lists.
+        @param func: wrapped function. This function should return a PythonDStream object.
+ @param expected: expected output for this testcase.
+ """
+ if not isinstance(input[0], RDD):
+ input = [self.sc.parallelize(d, 1) for d in input]
+ input_stream = self.ssc.queueStream(input)
+ if input2 and not isinstance(input2[0], RDD):
+ input2 = [self.sc.parallelize(d, 1) for d in input2]
+ input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
+
+ # Apply test function to stream.
+ if input2:
+ stream = func(input_stream, input_stream2)
+ else:
+ stream = func(input_stream)
+
+ result = self._collect(stream, len(expected))
+ if sort:
+ self._sort_result_based_on_key(result)
+ self._sort_result_based_on_key(expected)
+ self.assertEqual(expected, result)
+
+ def _sort_result_based_on_key(self, outputs):
+ """Sort the list based on first value."""
+ for output in outputs:
+ output.sort(key=lambda x: x[0])
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
new file mode 100644
index 0000000000000..7df0acae026f3
--- /dev/null
+++ b/python/pyspark/testing/utils.py
@@ -0,0 +1,102 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import struct
+import sys
+import unittest
+
+from pyspark import SparkContext, SparkConf
+
+
+have_scipy = False
+have_numpy = False
+try:
+ import scipy.sparse
+ have_scipy = True
+except:
+ # No SciPy, but that's okay, we'll skip those tests
+ pass
+try:
+ import numpy as np
+ have_numpy = True
+except:
+ # No NumPy, but that's okay, we'll skip those tests
+ pass
+
+
+SPARK_HOME = os.environ["SPARK_HOME"]
+
+
+def read_int(b):
+ return struct.unpack("!i", b)[0]
+
+
+def write_int(i):
+ return struct.pack("!i", i)
+
+
+class QuietTest(object):
+ def __init__(self, sc):
+ self.log4j = sc._jvm.org.apache.log4j
+
+ def __enter__(self):
+ self.old_level = self.log4j.LogManager.getRootLogger().getLevel()
+ self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
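+
+# Usage sketch: silence expected executor errors while asserting that a job fails, e.g.
+#   with QuietTest(self.sc):
+#       self.assertRaises(Exception, lambda: rdd.count())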
+
+
+class PySparkTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self._old_sys_path = list(sys.path)
+ class_name = self.__class__.__name__
+ self.sc = SparkContext('local[4]', class_name)
+
+ def tearDown(self):
+ self.sc.stop()
+ sys.path = self._old_sys_path
+
+
+class ReusedPySparkTestCase(unittest.TestCase):
+
+ @classmethod
+ def conf(cls):
+ """
+ Override this in subclasses to supply a more specific conf
+ """
+ return SparkConf()
+
+ @classmethod
+ def setUpClass(cls):
+ cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf())
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.sc.stop()
+
+
+class ByteArrayOutput(object):
+ def __init__(self):
+ self.buffer = bytearray()
+
+ def write(self, b):
+ self.buffer += b
+
+ def close(self):
+ pass
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
deleted file mode 100644
index c15d443ebbba9..0000000000000
--- a/python/pyspark/tests.py
+++ /dev/null
@@ -1,2522 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Unit tests for PySpark; additional tests are implemented as doctests in
-individual modules.
-"""
-
-from array import array
-from glob import glob
-import os
-import re
-import shutil
-import subprocess
-import sys
-import tempfile
-import time
-import zipfile
-import random
-import threading
-import hashlib
-
-from py4j.protocol import Py4JJavaError
-xmlrunner = None
-
-if sys.version_info[:2] <= (2, 6):
- try:
- import unittest2 as unittest
- except ImportError:
- sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
- sys.exit(1)
-else:
- import unittest
- if sys.version_info[0] >= 3:
- xrange = range
- basestring = str
-
-import unishark
-
-if sys.version >= "3":
- from io import StringIO
-else:
- from StringIO import StringIO
-
-
-from pyspark import keyword_only
-from pyspark.conf import SparkConf
-from pyspark.context import SparkContext
-from pyspark.rdd import RDD
-from pyspark.files import SparkFiles
-from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
- CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \
- PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \
- FlattenedValuesSerializer
-from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
-from pyspark import shuffle
-from pyspark.profiler import BasicProfiler
-from pyspark.taskcontext import BarrierTaskContext, TaskContext
-
-_have_scipy = False
-_have_numpy = False
-try:
- import scipy.sparse
- _have_scipy = True
-except:
- # No SciPy, but that's okay, we'll skip those tests
- pass
-try:
- import numpy as np
- _have_numpy = True
-except:
- # No NumPy, but that's okay, we'll skip those tests
- pass
-
-
-SPARK_HOME = os.environ["SPARK_HOME"]
-
-
-class MergerTests(unittest.TestCase):
-
- def setUp(self):
- self.N = 1 << 12
- self.l = [i for i in xrange(self.N)]
- self.data = list(zip(self.l, self.l))
- self.agg = Aggregator(lambda x: [x],
- lambda x, y: x.append(y) or x,
- lambda x, y: x.extend(y) or x)
-
- def test_small_dataset(self):
- m = ExternalMerger(self.agg, 1000)
- m.mergeValues(self.data)
- self.assertEqual(m.spills, 0)
- self.assertEqual(sum(sum(v) for k, v in m.items()),
- sum(xrange(self.N)))
-
- m = ExternalMerger(self.agg, 1000)
- m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data))
- self.assertEqual(m.spills, 0)
- self.assertEqual(sum(sum(v) for k, v in m.items()),
- sum(xrange(self.N)))
-
- def test_medium_dataset(self):
- m = ExternalMerger(self.agg, 20)
- m.mergeValues(self.data)
- self.assertTrue(m.spills >= 1)
- self.assertEqual(sum(sum(v) for k, v in m.items()),
- sum(xrange(self.N)))
-
- m = ExternalMerger(self.agg, 10)
- m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3))
- self.assertTrue(m.spills >= 1)
- self.assertEqual(sum(sum(v) for k, v in m.items()),
- sum(xrange(self.N)) * 3)
-
- def test_huge_dataset(self):
- m = ExternalMerger(self.agg, 5, partitions=3)
- m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10))
- self.assertTrue(m.spills >= 1)
- self.assertEqual(sum(len(v) for k, v in m.items()),
- self.N * 10)
- m._cleanup()
-
- def test_group_by_key(self):
-
- def gen_data(N, step):
- for i in range(1, N + 1, step):
- for j in range(i):
- yield (i, [j])
-
- def gen_gs(N, step=1):
- return shuffle.GroupByKey(gen_data(N, step))
-
- self.assertEqual(1, len(list(gen_gs(1))))
- self.assertEqual(2, len(list(gen_gs(2))))
- self.assertEqual(100, len(list(gen_gs(100))))
- self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)])
- self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100)))
-
- for k, vs in gen_gs(50002, 10000):
- self.assertEqual(k, len(vs))
- self.assertEqual(list(range(k)), list(vs))
-
- ser = PickleSerializer()
- l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
- for k, vs in l:
- self.assertEqual(k, len(vs))
- self.assertEqual(list(range(k)), list(vs))
-
- def test_stopiteration_is_raised(self):
-
- def stopit(*args, **kwargs):
- raise StopIteration()
-
- def legit_create_combiner(x):
- return [x]
-
- def legit_merge_value(x, y):
- return x.append(y) or x
-
- def legit_merge_combiners(x, y):
- return x.extend(y) or x
-
- data = [(x % 2, x) for x in range(100)]
-
- # wrong create combiner
- m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
- with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
- m.mergeValues(data)
-
- # wrong merge value
- m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
- with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
- m.mergeValues(data)
-
- # wrong merge combiners
- m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
- with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
- m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
-
-
-class SorterTests(unittest.TestCase):
- def test_in_memory_sort(self):
- l = list(range(1024))
- random.shuffle(l)
- sorter = ExternalSorter(1024)
- self.assertEqual(sorted(l), list(sorter.sorted(l)))
- self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
- self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
- self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
- list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
-
- def test_external_sort(self):
- class CustomizedSorter(ExternalSorter):
- def _next_limit(self):
- return self.memory_limit
- l = list(range(1024))
- random.shuffle(l)
- sorter = CustomizedSorter(1)
- self.assertEqual(sorted(l), list(sorter.sorted(l)))
- self.assertGreater(shuffle.DiskBytesSpilled, 0)
- last = shuffle.DiskBytesSpilled
- self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
- self.assertGreater(shuffle.DiskBytesSpilled, last)
- last = shuffle.DiskBytesSpilled
- self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
- self.assertGreater(shuffle.DiskBytesSpilled, last)
- last = shuffle.DiskBytesSpilled
- self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
- list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
- self.assertGreater(shuffle.DiskBytesSpilled, last)
-
- def test_external_sort_in_rdd(self):
- conf = SparkConf().set("spark.python.worker.memory", "1m")
- sc = SparkContext(conf=conf)
- l = list(range(10240))
- random.shuffle(l)
- rdd = sc.parallelize(l, 4)
- self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
- sc.stop()
-
-
-class SerializationTestCase(unittest.TestCase):
-
- def test_namedtuple(self):
- from collections import namedtuple
- from pickle import dumps, loads
- P = namedtuple("P", "x y")
- p1 = P(1, 3)
- p2 = loads(dumps(p1, 2))
- self.assertEqual(p1, p2)
-
- from pyspark.cloudpickle import dumps
- P2 = loads(dumps(P))
- p3 = P2(1, 3)
- self.assertEqual(p1, p3)
-
- def test_itemgetter(self):
- from operator import itemgetter
- ser = CloudPickleSerializer()
- d = range(10)
- getter = itemgetter(1)
- getter2 = ser.loads(ser.dumps(getter))
- self.assertEqual(getter(d), getter2(d))
-
- getter = itemgetter(0, 3)
- getter2 = ser.loads(ser.dumps(getter))
- self.assertEqual(getter(d), getter2(d))
-
- def test_function_module_name(self):
- ser = CloudPickleSerializer()
- func = lambda x: x
- func2 = ser.loads(ser.dumps(func))
- self.assertEqual(func.__module__, func2.__module__)
-
- def test_attrgetter(self):
- from operator import attrgetter
- ser = CloudPickleSerializer()
-
- class C(object):
- def __getattr__(self, item):
- return item
- d = C()
- getter = attrgetter("a")
- getter2 = ser.loads(ser.dumps(getter))
- self.assertEqual(getter(d), getter2(d))
- getter = attrgetter("a", "b")
- getter2 = ser.loads(ser.dumps(getter))
- self.assertEqual(getter(d), getter2(d))
-
- d.e = C()
- getter = attrgetter("e.a")
- getter2 = ser.loads(ser.dumps(getter))
- self.assertEqual(getter(d), getter2(d))
- getter = attrgetter("e.a", "e.b")
- getter2 = ser.loads(ser.dumps(getter))
- self.assertEqual(getter(d), getter2(d))
-
- # Regression test for SPARK-3415
- def test_pickling_file_handles(self):
- # to be corrected with SPARK-11160
- if not xmlrunner:
- ser = CloudPickleSerializer()
- out1 = sys.stderr
- out2 = ser.loads(ser.dumps(out1))
- self.assertEqual(out1, out2)
-
- def test_func_globals(self):
-
- class Unpicklable(object):
- def __reduce__(self):
- raise Exception("not picklable")
-
- global exit
- exit = Unpicklable()
-
- ser = CloudPickleSerializer()
- self.assertRaises(Exception, lambda: ser.dumps(exit))
-
- def foo():
- sys.exit(0)
-
- self.assertTrue("exit" in foo.__code__.co_names)
- ser.dumps(foo)
-
- def test_compressed_serializer(self):
- ser = CompressedSerializer(PickleSerializer())
- try:
- from StringIO import StringIO
- except ImportError:
- from io import BytesIO as StringIO
- io = StringIO()
- ser.dump_stream(["abc", u"123", range(5)], io)
- io.seek(0)
- self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
- ser.dump_stream(range(1000), io)
- io.seek(0)
- self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io)))
- io.close()
-
- def test_hash_serializer(self):
- hash(NoOpSerializer())
- hash(UTF8Deserializer())
- hash(PickleSerializer())
- hash(MarshalSerializer())
- hash(AutoSerializer())
- hash(BatchedSerializer(PickleSerializer()))
- hash(AutoBatchedSerializer(MarshalSerializer()))
- hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer()))
- hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()))
- hash(CompressedSerializer(PickleSerializer()))
- hash(FlattenedValuesSerializer(PickleSerializer()))
-
-
-class QuietTest(object):
- def __init__(self, sc):
- self.log4j = sc._jvm.org.apache.log4j
-
- def __enter__(self):
- self.old_level = self.log4j.LogManager.getRootLogger().getLevel()
- self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL)
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
-
-
-class PySparkTestCase(unittest.TestCase):
-
- def setUp(self):
- self._old_sys_path = list(sys.path)
- class_name = self.__class__.__name__
- self.sc = SparkContext('local[4]', class_name)
-
- def tearDown(self):
- self.sc.stop()
- sys.path = self._old_sys_path
-
-
-class ReusedPySparkTestCase(unittest.TestCase):
-
- @classmethod
- def conf(cls):
- """
- Override this in subclasses to supply a more specific conf
- """
- return SparkConf()
-
- @classmethod
- def setUpClass(cls):
- cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf())
-
- @classmethod
- def tearDownClass(cls):
- cls.sc.stop()
-
-
-class CheckpointTests(ReusedPySparkTestCase):
-
- def setUp(self):
- self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
- os.unlink(self.checkpointDir.name)
- self.sc.setCheckpointDir(self.checkpointDir.name)
-
- def tearDown(self):
- shutil.rmtree(self.checkpointDir.name)
-
- def test_basic_checkpointing(self):
- parCollection = self.sc.parallelize([1, 2, 3, 4])
- flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
-
- self.assertFalse(flatMappedRDD.isCheckpointed())
- self.assertTrue(flatMappedRDD.getCheckpointFile() is None)
-
- flatMappedRDD.checkpoint()
- result = flatMappedRDD.collect()
- time.sleep(1) # 1 second
- self.assertTrue(flatMappedRDD.isCheckpointed())
- self.assertEqual(flatMappedRDD.collect(), result)
- self.assertEqual("file:" + self.checkpointDir.name,
- os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile())))
-
- def test_checkpoint_and_restore(self):
- parCollection = self.sc.parallelize([1, 2, 3, 4])
- flatMappedRDD = parCollection.flatMap(lambda x: [x])
-
- self.assertFalse(flatMappedRDD.isCheckpointed())
- self.assertTrue(flatMappedRDD.getCheckpointFile() is None)
-
- flatMappedRDD.checkpoint()
- flatMappedRDD.count() # forces a checkpoint to be computed
- time.sleep(1) # 1 second
-
- self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
- recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
- flatMappedRDD._jrdd_deserializer)
- self.assertEqual([1, 2, 3, 4], recovered.collect())
-
-
-class LocalCheckpointTests(ReusedPySparkTestCase):
-
- def test_basic_localcheckpointing(self):
- parCollection = self.sc.parallelize([1, 2, 3, 4])
- flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
-
- self.assertFalse(flatMappedRDD.isCheckpointed())
- self.assertFalse(flatMappedRDD.isLocallyCheckpointed())
-
- flatMappedRDD.localCheckpoint()
- result = flatMappedRDD.collect()
- time.sleep(1) # 1 second
- self.assertTrue(flatMappedRDD.isCheckpointed())
- self.assertTrue(flatMappedRDD.isLocallyCheckpointed())
- self.assertEqual(flatMappedRDD.collect(), result)
-
-
-class AddFileTests(PySparkTestCase):
-
- def test_add_py_file(self):
- # To ensure that we're actually testing addPyFile's effects, check that
- # this job fails due to `userlibrary` not being on the Python path:
- # disable logging in log4j temporarily
- def func(x):
- from userlibrary import UserClass
- return UserClass().hello()
- with QuietTest(self.sc):
- self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first)
-
- # Add the file, so the job should now succeed:
- path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
- self.sc.addPyFile(path)
- res = self.sc.parallelize(range(2)).map(func).first()
- self.assertEqual("Hello World!", res)
-
- def test_add_file_locally(self):
- path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
- self.sc.addFile(path)
- download_path = SparkFiles.get("hello.txt")
- self.assertNotEqual(path, download_path)
- with open(download_path) as test_file:
- self.assertEqual("Hello World!\n", test_file.readline())
-
- def test_add_file_recursively_locally(self):
- path = os.path.join(SPARK_HOME, "python/test_support/hello")
- self.sc.addFile(path, True)
- download_path = SparkFiles.get("hello")
- self.assertNotEqual(path, download_path)
- with open(download_path + "/hello.txt") as test_file:
- self.assertEqual("Hello World!\n", test_file.readline())
- with open(download_path + "/sub_hello/sub_hello.txt") as test_file:
- self.assertEqual("Sub Hello World!\n", test_file.readline())
-
- def test_add_py_file_locally(self):
- # To ensure that we're actually testing addPyFile's effects, check that
- # this fails due to `userlibrary` not being on the Python path:
- def func():
- from userlibrary import UserClass
- self.assertRaises(ImportError, func)
- path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
- self.sc.addPyFile(path)
- from userlibrary import UserClass
- self.assertEqual("Hello World!", UserClass().hello())
-
- def test_add_egg_file_locally(self):
- # To ensure that we're actually testing addPyFile's effects, check that
- # this fails due to `userlibrary` not being on the Python path:
- def func():
- from userlib import UserClass
- self.assertRaises(ImportError, func)
- path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip")
- self.sc.addPyFile(path)
- from userlib import UserClass
- self.assertEqual("Hello World from inside a package!", UserClass().hello())
-
- def test_overwrite_system_module(self):
- self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py"))
-
- import SimpleHTTPServer
- self.assertEqual("My Server", SimpleHTTPServer.__name__)
-
- def func(x):
- import SimpleHTTPServer
- return SimpleHTTPServer.__name__
-
- self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
-
-
-class TaskContextTests(PySparkTestCase):
-
- def setUp(self):
- self._old_sys_path = list(sys.path)
- class_name = self.__class__.__name__
- # Allow retries even though they are normally disabled in local mode
- self.sc = SparkContext('local[4, 2]', class_name)
-
- def test_stage_id(self):
- """Test the stage ids are available and incrementing as expected."""
- rdd = self.sc.parallelize(range(10))
- stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
- stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
- # Test using the constructor directly rather than the get()
- stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
- self.assertEqual(stage1 + 1, stage2)
- self.assertEqual(stage1 + 2, stage3)
- self.assertEqual(stage2 + 1, stage3)
-
- def test_partition_id(self):
- """Test the partition id."""
- rdd1 = self.sc.parallelize(range(10), 1)
- rdd2 = self.sc.parallelize(range(10), 2)
- pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
- pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
- self.assertEqual(0, pids1[0])
- self.assertEqual(0, pids1[9])
- self.assertEqual(0, pids2[0])
- self.assertEqual(1, pids2[9])
-
- def test_attempt_number(self):
- """Verify the attempt numbers are correctly reported."""
- rdd = self.sc.parallelize(range(10))
- # Verify a simple job with no failures
- attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
- map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers)
-
- def fail_on_first(x):
- """Fail on the first attempt so we get a positive attempt number"""
- tc = TaskContext.get()
- attempt_number = tc.attemptNumber()
- partition_id = tc.partitionId()
- attempt_id = tc.taskAttemptId()
- if attempt_number == 0 and partition_id == 0:
- raise Exception("Failing on first attempt")
- else:
- return [x, partition_id, attempt_number, attempt_id]
- result = rdd.map(fail_on_first).collect()
- # We should re-submit the first partition to it but other partitions should be attempt 0
- self.assertEqual([0, 0, 1], result[0][0:3])
- self.assertEqual([9, 3, 0], result[9][0:3])
- first_partition = filter(lambda x: x[1] == 0, result)
- map(lambda x: self.assertEqual(1, x[2]), first_partition)
- other_partitions = filter(lambda x: x[1] != 0, result)
- map(lambda x: self.assertEqual(0, x[2]), other_partitions)
- # The task attempt id should be different
- self.assertTrue(result[0][3] != result[9][3])
-
- def test_tc_on_driver(self):
- """Verify that getting the TaskContext on the driver returns None."""
- tc = TaskContext.get()
- self.assertTrue(tc is None)
-
- def test_get_local_property(self):
- """Verify that local properties set on the driver are available in TaskContext."""
- key = "testkey"
- value = "testvalue"
- self.sc.setLocalProperty(key, value)
- try:
- rdd = self.sc.parallelize(range(1), 1)
- prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0]
- self.assertEqual(prop1, value)
- prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
- self.assertTrue(prop2 is None)
- finally:
- self.sc.setLocalProperty(key, None)
-
- def test_barrier(self):
- """
- Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
- within a stage.
- """
- rdd = self.sc.parallelize(range(10), 4)
-
- def f(iterator):
- yield sum(iterator)
-
- def context_barrier(x):
- tc = BarrierTaskContext.get()
- time.sleep(random.randint(1, 10))
- tc.barrier()
- return time.time()
-
- times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
- self.assertTrue(max(times) - min(times) < 1)
-
- def test_barrier_infos(self):
- """
- Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
- barrier stage.
- """
- rdd = self.sc.parallelize(range(10), 4)
-
- def f(iterator):
- yield sum(iterator)
-
- taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get()
- .getTaskInfos()).collect()
- self.assertTrue(len(taskInfos) == 4)
- self.assertTrue(len(taskInfos[0]) == 4)
-
-
-class RDDTests(ReusedPySparkTestCase):
-
- def test_range(self):
- self.assertEqual(self.sc.range(1, 1).count(), 0)
- self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
- self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)
-
- def test_id(self):
- rdd = self.sc.parallelize(range(10))
- id = rdd.id()
- self.assertEqual(id, rdd.id())
- rdd2 = rdd.map(str).filter(bool)
- id2 = rdd2.id()
- self.assertEqual(id + 1, id2)
- self.assertEqual(id2, rdd2.id())
-
- def test_empty_rdd(self):
- rdd = self.sc.emptyRDD()
- self.assertTrue(rdd.isEmpty())
-
- def test_sum(self):
- self.assertEqual(0, self.sc.emptyRDD().sum())
- self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
-
- def test_to_localiterator(self):
- from time import sleep
- rdd = self.sc.parallelize([1, 2, 3])
- it = rdd.toLocalIterator()
- sleep(5)
- self.assertEqual([1, 2, 3], sorted(it))
-
- rdd2 = rdd.repartition(1000)
- it2 = rdd2.toLocalIterator()
- sleep(5)
- self.assertEqual([1, 2, 3], sorted(it2))
-
- def test_save_as_textfile_with_unicode(self):
- # Regression test for SPARK-970
- x = u"\u00A1Hola, mundo!"
- data = self.sc.parallelize([x])
- tempFile = tempfile.NamedTemporaryFile(delete=True)
- tempFile.close()
- data.saveAsTextFile(tempFile.name)
- raw_contents = b''.join(open(p, 'rb').read()
- for p in glob(tempFile.name + "/part-0000*"))
- self.assertEqual(x, raw_contents.strip().decode("utf-8"))
-
- def test_save_as_textfile_with_utf8(self):
- x = u"\u00A1Hola, mundo!"
- data = self.sc.parallelize([x.encode("utf-8")])
- tempFile = tempfile.NamedTemporaryFile(delete=True)
- tempFile.close()
- data.saveAsTextFile(tempFile.name)
- raw_contents = b''.join(open(p, 'rb').read()
- for p in glob(tempFile.name + "/part-0000*"))
- self.assertEqual(x, raw_contents.strip().decode('utf8'))
-
- def test_transforming_cartesian_result(self):
- # Regression test for SPARK-1034
- rdd1 = self.sc.parallelize([1, 2])
- rdd2 = self.sc.parallelize([3, 4])
- cart = rdd1.cartesian(rdd2)
- result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect()
-
- def test_transforming_pickle_file(self):
- # Regression test for SPARK-2601
- data = self.sc.parallelize([u"Hello", u"World!"])
- tempFile = tempfile.NamedTemporaryFile(delete=True)
- tempFile.close()
- data.saveAsPickleFile(tempFile.name)
- pickled_file = self.sc.pickleFile(tempFile.name)
- pickled_file.map(lambda x: x).collect()
-
- def test_cartesian_on_textfile(self):
- # Regression test for
- path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
- a = self.sc.textFile(path)
- result = a.cartesian(a).collect()
- (x, y) = result[0]
- self.assertEqual(u"Hello World!", x.strip())
- self.assertEqual(u"Hello World!", y.strip())
-
- def test_cartesian_chaining(self):
- # Tests for SPARK-16589
- rdd = self.sc.parallelize(range(10), 2)
- self.assertSetEqual(
- set(rdd.cartesian(rdd).cartesian(rdd).collect()),
- set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)])
- )
-
- self.assertSetEqual(
- set(rdd.cartesian(rdd.cartesian(rdd)).collect()),
- set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)])
- )
-
- self.assertSetEqual(
- set(rdd.cartesian(rdd.zip(rdd)).collect()),
- set([(x, (y, y)) for x in range(10) for y in range(10)])
- )
-
- def test_zip_chaining(self):
- # Tests for SPARK-21985
- rdd = self.sc.parallelize('abc', 2)
- self.assertSetEqual(
- set(rdd.zip(rdd).zip(rdd).collect()),
- set([((x, x), x) for x in 'abc'])
- )
- self.assertSetEqual(
- set(rdd.zip(rdd.zip(rdd)).collect()),
- set([(x, (x, x)) for x in 'abc'])
- )
-
- def test_deleting_input_files(self):
- # Regression test for SPARK-1025
- tempFile = tempfile.NamedTemporaryFile(delete=False)
- tempFile.write(b"Hello World!")
- tempFile.close()
- data = self.sc.textFile(tempFile.name)
- filtered_data = data.filter(lambda x: True)
- self.assertEqual(1, filtered_data.count())
- os.unlink(tempFile.name)
- with QuietTest(self.sc):
- self.assertRaises(Exception, lambda: filtered_data.count())
-
- def test_sampling_default_seed(self):
- # Test for SPARK-3995 (default seed setting)
- data = self.sc.parallelize(xrange(1000), 1)
- subset = data.takeSample(False, 10)
- self.assertEqual(len(subset), 10)
-
- def test_aggregate_mutable_zero_value(self):
- # Test for SPARK-9021; uses aggregate and treeAggregate to build dict
- # representing a counter of ints
- # NOTE: dict is used instead of collections.Counter for Python 2.6
- # compatibility
- from collections import defaultdict
-
- # Show that single or multiple partitions work
- data1 = self.sc.range(10, numSlices=1)
- data2 = self.sc.range(10, numSlices=2)
-
- def seqOp(x, y):
- x[y] += 1
- return x
-
- def comboOp(x, y):
- for key, val in y.items():
- x[key] += val
- return x
-
- counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp)
- counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp)
- counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
- counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
-
- ground_truth = defaultdict(int, dict((i, 1) for i in range(10)))
- self.assertEqual(counts1, ground_truth)
- self.assertEqual(counts2, ground_truth)
- self.assertEqual(counts3, ground_truth)
- self.assertEqual(counts4, ground_truth)
-
- def test_aggregate_by_key_mutable_zero_value(self):
- # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that
- # contains lists of all values for each key in the original RDD
-
- # list(range(...)) for Python 3.x compatibility (can't use * operator
- # on a range object)
- # list(zip(...)) for Python 3.x compatibility (want to parallelize a
- # collection, not a zip object)
- tuples = list(zip(list(range(10))*2, [1]*20))
- # Show that single or multiple partitions work
- data1 = self.sc.parallelize(tuples, 1)
- data2 = self.sc.parallelize(tuples, 2)
-
- def seqOp(x, y):
- x.append(y)
- return x
-
- def comboOp(x, y):
- x.extend(y)
- return x
-
- values1 = data1.aggregateByKey([], seqOp, comboOp).collect()
- values2 = data2.aggregateByKey([], seqOp, comboOp).collect()
- # Sort lists to ensure clean comparison with ground_truth
- values1.sort()
- values2.sort()
-
- ground_truth = [(i, [1]*2) for i in range(10)]
- self.assertEqual(values1, ground_truth)
- self.assertEqual(values2, ground_truth)
-
- def test_fold_mutable_zero_value(self):
- # Test for SPARK-9021; uses fold to merge an RDD of dict counters into
- # a single dict
- # NOTE: dict is used instead of collections.Counter for Python 2.6
- # compatibility
- from collections import defaultdict
-
- counts1 = defaultdict(int, dict((i, 1) for i in range(10)))
- counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8)))
- counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7)))
- counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6)))
- all_counts = [counts1, counts2, counts3, counts4]
- # Show that single or multiple partitions work
- data1 = self.sc.parallelize(all_counts, 1)
- data2 = self.sc.parallelize(all_counts, 2)
-
- def comboOp(x, y):
- for key, val in y.items():
- x[key] += val
- return x
-
- fold1 = data1.fold(defaultdict(int), comboOp)
- fold2 = data2.fold(defaultdict(int), comboOp)
-
- ground_truth = defaultdict(int)
- for counts in all_counts:
- for key, val in counts.items():
- ground_truth[key] += val
- self.assertEqual(fold1, ground_truth)
- self.assertEqual(fold2, ground_truth)
-
- def test_fold_by_key_mutable_zero_value(self):
- # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains
- # lists of all values for each key in the original RDD
-
- tuples = [(i, range(i)) for i in range(10)]*2
- # Show that single or multiple partitions work
- data1 = self.sc.parallelize(tuples, 1)
- data2 = self.sc.parallelize(tuples, 2)
-
- def comboOp(x, y):
- x.extend(y)
- return x
-
- values1 = data1.foldByKey([], comboOp).collect()
- values2 = data2.foldByKey([], comboOp).collect()
- # Sort lists to ensure clean comparison with ground_truth
- values1.sort()
- values2.sort()
-
- # list(range(...)) for Python 3.x compatibility
- ground_truth = [(i, list(range(i))*2) for i in range(10)]
- self.assertEqual(values1, ground_truth)
- self.assertEqual(values2, ground_truth)
-
- def test_aggregate_by_key(self):
- data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
-
- def seqOp(x, y):
- x.add(y)
- return x
-
- def combOp(x, y):
- x |= y
- return x
-
- sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
- self.assertEqual(3, len(sets))
- self.assertEqual(set([1]), sets[1])
- self.assertEqual(set([2]), sets[3])
- self.assertEqual(set([1, 3]), sets[5])
-
- def test_itemgetter(self):
- rdd = self.sc.parallelize([range(10)])
- from operator import itemgetter
- self.assertEqual([1], rdd.map(itemgetter(1)).collect())
- self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect())
-
- def test_namedtuple_in_rdd(self):
- from collections import namedtuple
- Person = namedtuple("Person", "id firstName lastName")
- jon = Person(1, "Jon", "Doe")
- jane = Person(2, "Jane", "Doe")
- theDoes = self.sc.parallelize([jon, jane])
- self.assertEqual([jon, jane], theDoes.collect())
-
- def test_large_broadcast(self):
- N = 10000
- data = [[float(i) for i in range(300)] for i in range(N)]
- bdata = self.sc.broadcast(data) # 27MB
- m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
- self.assertEqual(N, m)
-
- def test_unpersist(self):
- N = 1000
- data = [[float(i) for i in range(300)] for i in range(N)]
- bdata = self.sc.broadcast(data) # 3MB
- bdata.unpersist()
- m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
- self.assertEqual(N, m)
- bdata.destroy()
- try:
- self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
- except Exception as e:
- pass
- else:
- raise Exception("job should fail after destroy the broadcast")
-
- def test_multiple_broadcasts(self):
- N = 1 << 21
- b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM
- r = list(range(1 << 15))
- random.shuffle(r)
- s = str(r).encode()
- checksum = hashlib.md5(s).hexdigest()
- b2 = self.sc.broadcast(s)
- r = list(set(self.sc.parallelize(range(10), 10).map(
- lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
- self.assertEqual(1, len(r))
- size, csum = r[0]
- self.assertEqual(N, size)
- self.assertEqual(checksum, csum)
-
- random.shuffle(r)
- s = str(r).encode()
- checksum = hashlib.md5(s).hexdigest()
- b2 = self.sc.broadcast(s)
- r = list(set(self.sc.parallelize(range(10), 10).map(
- lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
- self.assertEqual(1, len(r))
- size, csum = r[0]
- self.assertEqual(N, size)
- self.assertEqual(checksum, csum)
-
- def test_multithread_broadcast_pickle(self):
- import threading
-
- b1 = self.sc.broadcast(list(range(3)))
- b2 = self.sc.broadcast(list(range(3)))
-
- def f1():
- return b1.value
-
- def f2():
- return b2.value
-
- funcs_num_pickled = {f1: None, f2: None}
-
- def do_pickle(f, sc):
- command = (f, None, sc.serializer, sc.serializer)
- ser = CloudPickleSerializer()
- ser.dumps(command)
-
- def process_vars(sc):
- broadcast_vars = list(sc._pickled_broadcast_vars)
- num_pickled = len(broadcast_vars)
- sc._pickled_broadcast_vars.clear()
- return num_pickled
-
- def run(f, sc):
- do_pickle(f, sc)
- funcs_num_pickled[f] = process_vars(sc)
-
- # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage
- do_pickle(f1, self.sc)
-
- # run all for f2, should only add/count/clear b2 from worker thread local storage
- t = threading.Thread(target=run, args=(f2, self.sc))
- t.start()
- t.join()
-
- # count number of vars pickled in main thread, only b1 should be counted and cleared
- funcs_num_pickled[f1] = process_vars(self.sc)
-
- self.assertEqual(funcs_num_pickled[f1], 1)
- self.assertEqual(funcs_num_pickled[f2], 1)
- self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
-
- def test_large_closure(self):
- N = 200000
- data = [float(i) for i in xrange(N)]
- rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
- self.assertEqual(N, rdd.first())
- # regression test for SPARK-6886
- self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())
-
- def test_zip_with_different_serializers(self):
- a = self.sc.parallelize(range(5))
- b = self.sc.parallelize(range(100, 105))
- self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
- a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
- b = b._reserialize(MarshalSerializer())
- self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
- # regression test for SPARK-4841
- path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
- t = self.sc.textFile(path)
- cnt = t.count()
- self.assertEqual(cnt, t.zip(t).count())
- rdd = t.map(str)
- self.assertEqual(cnt, t.zip(rdd).count())
- # regression test for a bug in _reserialize()
- self.assertEqual(cnt, t.zip(rdd).count())
-
- def test_zip_with_different_object_sizes(self):
- # regression test for SPARK-5973
- a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i)
- b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i)
- self.assertEqual(10000, a.zip(b).count())
-
- def test_zip_with_different_number_of_items(self):
- a = self.sc.parallelize(range(5), 2)
- # different number of partitions
- b = self.sc.parallelize(range(100, 106), 3)
- self.assertRaises(ValueError, lambda: a.zip(b))
- with QuietTest(self.sc):
- # different number of batched items in JVM
- b = self.sc.parallelize(range(100, 104), 2)
- self.assertRaises(Exception, lambda: a.zip(b).count())
- # different number of items in one pair
- b = self.sc.parallelize(range(100, 106), 2)
- self.assertRaises(Exception, lambda: a.zip(b).count())
- # same total number of items, but different distributions
- a = self.sc.parallelize([2, 3], 2).flatMap(range)
- b = self.sc.parallelize([3, 2], 2).flatMap(range)
- self.assertEqual(a.count(), b.count())
- self.assertRaises(Exception, lambda: a.zip(b).count())
-
- def test_count_approx_distinct(self):
- rdd = self.sc.parallelize(xrange(1000))
- self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
- self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
- self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
- self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050)
-
- rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
- self.assertTrue(18 < rdd.countApproxDistinct() < 22)
- self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22)
- self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22)
- self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22)
-
- self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001))
-
- def test_histogram(self):
- # empty
- rdd = self.sc.parallelize([])
- self.assertEqual([0], rdd.histogram([0, 10])[1])
- self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
- self.assertRaises(ValueError, lambda: rdd.histogram(1))
-
- # out of range
- rdd = self.sc.parallelize([10.01, -0.01])
- self.assertEqual([0], rdd.histogram([0, 10])[1])
- self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1])
-
- # in range with one bucket
- rdd = self.sc.parallelize(range(1, 5))
- self.assertEqual([4], rdd.histogram([0, 10])[1])
- self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1])
-
- # in range with one bucket exact match
- self.assertEqual([4], rdd.histogram([1, 4])[1])
-
- # out of range with two buckets
- rdd = self.sc.parallelize([10.01, -0.01])
- self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1])
-
- # out of range with two uneven buckets
- rdd = self.sc.parallelize([10.01, -0.01])
- self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
-
- # in range with two buckets
- rdd = self.sc.parallelize([1, 2, 3, 5, 6])
- self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])
-
- # in range with two buckets and None
- rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')])
- self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])
-
- # in range with two uneven buckets
- rdd = self.sc.parallelize([1, 2, 3, 5, 6])
- self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1])
-
- # mixed range with two uneven buckets
- rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01])
- self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1])
-
- # mixed range with four uneven buckets
- rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1])
- self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
-
- # mixed range with uneven buckets and NaN
- rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0,
- 199.0, 200.0, 200.1, None, float('nan')])
- self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
-
- # out of range with infinite buckets
- rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")])
- self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1])
-
- # invalid buckets
- self.assertRaises(ValueError, lambda: rdd.histogram([]))
- self.assertRaises(ValueError, lambda: rdd.histogram([1]))
- self.assertRaises(ValueError, lambda: rdd.histogram(0))
- self.assertRaises(TypeError, lambda: rdd.histogram({}))
-
- # without buckets
- rdd = self.sc.parallelize(range(1, 5))
- self.assertEqual(([1, 4], [4]), rdd.histogram(1))
-
- # without buckets single element
- rdd = self.sc.parallelize([1])
- self.assertEqual(([1, 1], [1]), rdd.histogram(1))
-
- # without buckets, no range
- rdd = self.sc.parallelize([1] * 4)
- self.assertEqual(([1, 1], [4]), rdd.histogram(1))
-
- # without buckets basic two
- rdd = self.sc.parallelize(range(1, 5))
- self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2))
-
- # without buckets with more requested than elements
- rdd = self.sc.parallelize([1, 2])
- buckets = [1 + 0.2 * i for i in range(6)]
- hist = [1, 0, 0, 0, 1]
- self.assertEqual((buckets, hist), rdd.histogram(5))
-
- # invalid RDDs
- rdd = self.sc.parallelize([1, float('inf')])
- self.assertRaises(ValueError, lambda: rdd.histogram(2))
- rdd = self.sc.parallelize([float('nan')])
- self.assertRaises(ValueError, lambda: rdd.histogram(2))
-
- # string
- rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2)
- self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1])
- self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1))
- self.assertRaises(TypeError, lambda: rdd.histogram(2))
-
- def test_repartitionAndSortWithinPartitions_asc(self):
- rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
-
- repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True)
- partitions = repartitioned.glom().collect()
- self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)])
- self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)])
-
- def test_repartitionAndSortWithinPartitions_desc(self):
- rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
-
- repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False)
- partitions = repartitioned.glom().collect()
- self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)])
- self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)])
-
- def test_repartition_no_skewed(self):
- num_partitions = 20
- a = self.sc.parallelize(range(int(1000)), 2)
- l = a.repartition(num_partitions).glom().map(len).collect()
- zeros = len([x for x in l if x == 0])
- self.assertTrue(zeros == 0)
- l = a.coalesce(num_partitions, True).glom().map(len).collect()
- zeros = len([x for x in l if x == 0])
- self.assertTrue(zeros == 0)
-
- def test_repartition_on_textfile(self):
- path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
- rdd = self.sc.textFile(path)
- result = rdd.repartition(1).collect()
- self.assertEqual(u"Hello World!", result[0])
-
- def test_distinct(self):
- rdd = self.sc.parallelize((1, 2, 3)*10, 10)
- self.assertEqual(rdd.getNumPartitions(), 10)
- self.assertEqual(rdd.distinct().count(), 3)
- result = rdd.distinct(5)
- self.assertEqual(result.getNumPartitions(), 5)
- self.assertEqual(result.count(), 3)
-
- def test_external_group_by_key(self):
- self.sc._conf.set("spark.python.worker.memory", "1m")
- N = 200001
- kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
- gkv = kv.groupByKey().cache()
- self.assertEqual(3, gkv.count())
- filtered = gkv.filter(lambda kv: kv[0] == 1)
- self.assertEqual(1, filtered.count())
- self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect())
- self.assertEqual([(N // 3, N // 3)],
- filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
- result = filtered.collect()[0][1]
- self.assertEqual(N // 3, len(result))
- self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList))
-
- def test_sort_on_empty_rdd(self):
- self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
-
- def test_sample(self):
- rdd = self.sc.parallelize(range(0, 100), 4)
- wo = rdd.sample(False, 0.1, 2).collect()
- wo_dup = rdd.sample(False, 0.1, 2).collect()
- self.assertSetEqual(set(wo), set(wo_dup))
- wr = rdd.sample(True, 0.2, 5).collect()
- wr_dup = rdd.sample(True, 0.2, 5).collect()
- self.assertSetEqual(set(wr), set(wr_dup))
- wo_s10 = rdd.sample(False, 0.3, 10).collect()
- wo_s20 = rdd.sample(False, 0.3, 20).collect()
- self.assertNotEqual(set(wo_s10), set(wo_s20))
- wr_s11 = rdd.sample(True, 0.4, 11).collect()
- wr_s21 = rdd.sample(True, 0.4, 21).collect()
- self.assertNotEqual(set(wr_s11), set(wr_s21))
-
- def test_null_in_rdd(self):
- jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
- rdd = RDD(jrdd, self.sc, UTF8Deserializer())
- self.assertEqual([u"a", None, u"b"], rdd.collect())
- rdd = RDD(jrdd, self.sc, NoOpSerializer())
- self.assertEqual([b"a", None, b"b"], rdd.collect())
-
- def test_multiple_python_java_RDD_conversions(self):
- # Regression test for SPARK-5361
- data = [
- (u'1', {u'director': u'David Lean'}),
- (u'2', {u'director': u'Andrew Dominik'})
- ]
- data_rdd = self.sc.parallelize(data)
- data_java_rdd = data_rdd._to_java_object_rdd()
- data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd)
- converted_rdd = RDD(data_python_rdd, self.sc)
- self.assertEqual(2, converted_rdd.count())
-
- # conversion between Python and Java RDDs used to throw exceptions
- data_java_rdd = converted_rdd._to_java_object_rdd()
- data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd)
- converted_rdd = RDD(data_python_rdd, self.sc)
- self.assertEqual(2, converted_rdd.count())
-
- def test_narrow_dependency_in_join(self):
- rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x))
- parted = rdd.partitionBy(2)
- self.assertEqual(2, parted.union(parted).getNumPartitions())
- self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
- self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())
-
- tracker = self.sc.statusTracker()
-
- self.sc.setJobGroup("test1", "test", True)
- d = sorted(parted.join(parted).collect())
- self.assertEqual(10, len(d))
- self.assertEqual((0, (0, 0)), d[0])
- jobId = tracker.getJobIdsForGroup("test1")[0]
- self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
-
- self.sc.setJobGroup("test2", "test", True)
- d = sorted(parted.join(rdd).collect())
- self.assertEqual(10, len(d))
- self.assertEqual((0, (0, 0)), d[0])
- jobId = tracker.getJobIdsForGroup("test2")[0]
- self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
-
- self.sc.setJobGroup("test3", "test", True)
- d = sorted(parted.cogroup(parted).collect())
- self.assertEqual(10, len(d))
- self.assertEqual([[0], [0]], list(map(list, d[0][1])))
- jobId = tracker.getJobIdsForGroup("test3")[0]
- self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
-
- self.sc.setJobGroup("test4", "test", True)
- d = sorted(parted.cogroup(rdd).collect())
- self.assertEqual(10, len(d))
- self.assertEqual([[0], [0]], list(map(list, d[0][1])))
- jobId = tracker.getJobIdsForGroup("test4")[0]
- self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
-
- # Regression test for SPARK-6294
- def test_take_on_jrdd(self):
- rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x))
- rdd._jrdd.first()
-
- def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
- # Regression test for SPARK-5969
- seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence
- rdd = self.sc.parallelize(seq)
- for ascending in [True, False]:
- sort = rdd.sortByKey(ascending=ascending, numPartitions=5)
- self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending))
- sizes = sort.glom().map(len).collect()
- for size in sizes:
- self.assertGreater(size, 0)
-
- def test_pipe_functions(self):
- data = ['1', '2', '3']
- rdd = self.sc.parallelize(data)
- with QuietTest(self.sc):
- self.assertEqual([], rdd.pipe('cc').collect())
- self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
- result = rdd.pipe('cat').collect()
- result.sort()
- for x, y in zip(data, result):
- self.assertEqual(x, y)
- self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
- self.assertEqual([], rdd.pipe('grep 4').collect())
-
- def test_pipe_unicode(self):
- # Regression test for SPARK-20947
- data = [u'\u6d4b\u8bd5', '1']
- rdd = self.sc.parallelize(data)
- result = rdd.pipe('cat').collect()
- self.assertEqual(data, result)
-
- def test_stopiteration_in_user_code(self):
-
- def stopit(*x):
- raise StopIteration()
-
- seq_rdd = self.sc.parallelize(range(10))
- keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
- msg = "Caught StopIteration thrown from user's code; failing the task"
-
- self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
- self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
- self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
- self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
- self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
- self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
- self.assertRaisesRegexp(Py4JJavaError, msg,
- seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
-
- # These methods call the user function both in the driver and in the executor.
- # The exception raised differs depending on where the StopIteration happens:
- # a RuntimeError is raised if it happens in the driver,
- # a Py4JJavaError is raised if it happens in the executor (wrapping the RuntimeError raised in the worker)
- self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
- keyed_rdd.reduceByKeyLocally, stopit)
- self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
- seq_rdd.aggregate, 0, stopit, lambda *x: 1)
- self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
- seq_rdd.aggregate, 0, lambda *x: 1, stopit)
-
-
-class ProfilerTests(PySparkTestCase):
-
- def setUp(self):
- self._old_sys_path = list(sys.path)
- class_name = self.__class__.__name__
- conf = SparkConf().set("spark.python.profile", "true")
- self.sc = SparkContext('local[4]', class_name, conf=conf)
-
- def test_profiler(self):
- self.do_computation()
-
- profilers = self.sc.profiler_collector.profilers
- self.assertEqual(1, len(profilers))
- id, profiler, _ = profilers[0]
- stats = profiler.stats()
- self.assertTrue(stats is not None)
- width, stat_list = stats.get_print_list([])
- func_names = [func_name for fname, n, func_name in stat_list]
- self.assertTrue("heavy_foo" in func_names)
-
- old_stdout = sys.stdout
- sys.stdout = io = StringIO()
- self.sc.show_profiles()
- self.assertTrue("heavy_foo" in io.getvalue())
- sys.stdout = old_stdout
-
- d = tempfile.gettempdir()
- self.sc.dump_profiles(d)
- self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
-
- def test_custom_profiler(self):
- class TestCustomProfiler(BasicProfiler):
- def show(self, id):
- self.result = "Custom formatting"
-
- self.sc.profiler_collector.profiler_cls = TestCustomProfiler
-
- self.do_computation()
-
- profilers = self.sc.profiler_collector.profilers
- self.assertEqual(1, len(profilers))
- _, profiler, _ = profilers[0]
- self.assertTrue(isinstance(profiler, TestCustomProfiler))
-
- self.sc.show_profiles()
- self.assertEqual("Custom formatting", profiler.result)
-
- def do_computation(self):
- def heavy_foo(x):
- for i in range(1 << 18):
- x = 1
-
- rdd = self.sc.parallelize(range(100))
- rdd.foreach(heavy_foo)
-
-
-class ProfilerTests2(unittest.TestCase):
- def test_profiler_disabled(self):
- sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false"))
- try:
- self.assertRaisesRegexp(
- RuntimeError,
- "'spark.python.profile' configuration must be set",
- lambda: sc.show_profiles())
- self.assertRaisesRegexp(
- RuntimeError,
- "'spark.python.profile' configuration must be set",
- lambda: sc.dump_profiles("/tmp/abc"))
- finally:
- sc.stop()
-
-
-class InputFormatTests(ReusedPySparkTestCase):
-
- @classmethod
- def setUpClass(cls):
- ReusedPySparkTestCase.setUpClass()
- cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
- os.unlink(cls.tempdir.name)
- cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)
-
- @classmethod
- def tearDownClass(cls):
- ReusedPySparkTestCase.tearDownClass()
- shutil.rmtree(cls.tempdir.name)
-
- @unittest.skipIf(sys.version >= "3", "serialize array of byte")
- def test_sequencefiles(self):
- basepath = self.tempdir.name
- ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text").collect())
- ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
- self.assertEqual(ints, ei)
-
- doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/",
- "org.apache.hadoop.io.DoubleWritable",
- "org.apache.hadoop.io.Text").collect())
- ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')]
- self.assertEqual(doubles, ed)
-
- bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.BytesWritable").collect())
- ebs = [(1, bytearray('aa', 'utf-8')),
- (1, bytearray('aa', 'utf-8')),
- (2, bytearray('aa', 'utf-8')),
- (2, bytearray('bb', 'utf-8')),
- (2, bytearray('bb', 'utf-8')),
- (3, bytearray('cc', 'utf-8'))]
- self.assertEqual(bytes, ebs)
-
- text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/",
- "org.apache.hadoop.io.Text",
- "org.apache.hadoop.io.Text").collect())
- et = [(u'1', u'aa'),
- (u'1', u'aa'),
- (u'2', u'aa'),
- (u'2', u'bb'),
- (u'2', u'bb'),
- (u'3', u'cc')]
- self.assertEqual(text, et)
-
- bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.BooleanWritable").collect())
- eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)]
- self.assertEqual(bools, eb)
-
- nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.BooleanWritable").collect())
- en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
- self.assertEqual(nulls, en)
-
- maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.MapWritable").collect()
- em = [(1, {}),
- (1, {3.0: u'bb'}),
- (2, {1.0: u'aa'}),
- (2, {1.0: u'cc'}),
- (3, {2.0: u'dd'})]
- for v in maps:
- self.assertTrue(v in em)
-
- # arrays get pickled to tuples by default
- tuples = sorted(self.sc.sequenceFile(
- basepath + "/sftestdata/sfarray/",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.spark.api.python.DoubleArrayWritable").collect())
- et = [(1, ()),
- (2, (3.0, 4.0, 5.0)),
- (3, (4.0, 5.0, 6.0))]
- self.assertEqual(tuples, et)
-
- # with custom converters, primitive arrays can stay as arrays
- arrays = sorted(self.sc.sequenceFile(
- basepath + "/sftestdata/sfarray/",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.spark.api.python.DoubleArrayWritable",
- valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect())
- ea = [(1, array('d')),
- (2, array('d', [3.0, 4.0, 5.0])),
- (3, array('d', [4.0, 5.0, 6.0]))]
- self.assertEqual(arrays, ea)
-
- clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
- "org.apache.hadoop.io.Text",
- "org.apache.spark.api.python.TestWritable").collect())
- cname = u'org.apache.spark.api.python.TestWritable'
- ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}),
- (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}),
- (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}),
- (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}),
- (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})]
- self.assertEqual(clazz, ec)
-
- unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
- "org.apache.hadoop.io.Text",
- "org.apache.spark.api.python.TestWritable",
- ).collect())
- self.assertEqual(unbatched_clazz, ec)
-
- def test_oldhadoop(self):
- basepath = self.tempdir.name
- ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/",
- "org.apache.hadoop.mapred.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text").collect())
- ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
- self.assertEqual(ints, ei)
-
- hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
- oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath}
- hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat",
- "org.apache.hadoop.io.LongWritable",
- "org.apache.hadoop.io.Text",
- conf=oldconf).collect()
- result = [(0, u'Hello World!')]
- self.assertEqual(hello, result)
-
- def test_newhadoop(self):
- basepath = self.tempdir.name
- ints = sorted(self.sc.newAPIHadoopFile(
- basepath + "/sftestdata/sfint/",
- "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text").collect())
- ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
- self.assertEqual(ints, ei)
-
- hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
- newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath}
- hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat",
- "org.apache.hadoop.io.LongWritable",
- "org.apache.hadoop.io.Text",
- conf=newconf).collect()
- result = [(0, u'Hello World!')]
- self.assertEqual(hello, result)
-
- def test_newolderror(self):
- basepath = self.tempdir.name
- self.assertRaises(Exception, lambda: self.sc.hadoopFile(
- basepath + "/sftestdata/sfint/",
- "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text"))
-
- self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile(
- basepath + "/sftestdata/sfint/",
- "org.apache.hadoop.mapred.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text"))
-
- def test_bad_inputs(self):
- basepath = self.tempdir.name
- self.assertRaises(Exception, lambda: self.sc.sequenceFile(
- basepath + "/sftestdata/sfint/",
- "org.apache.hadoop.io.NotValidWritable",
- "org.apache.hadoop.io.Text"))
- self.assertRaises(Exception, lambda: self.sc.hadoopFile(
- basepath + "/sftestdata/sfint/",
- "org.apache.hadoop.mapred.NotValidInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text"))
- self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile(
- basepath + "/sftestdata/sfint/",
- "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text"))
-
- def test_converters(self):
- # use of custom converters
- basepath = self.tempdir.name
- maps = sorted(self.sc.sequenceFile(
- basepath + "/sftestdata/sfmap/",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.MapWritable",
- keyConverter="org.apache.spark.api.python.TestInputKeyConverter",
- valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect())
- em = [(u'\x01', []),
- (u'\x01', [3.0]),
- (u'\x02', [1.0]),
- (u'\x02', [1.0]),
- (u'\x03', [2.0])]
- self.assertEqual(maps, em)
-
- def test_binary_files(self):
- path = os.path.join(self.tempdir.name, "binaryfiles")
- os.mkdir(path)
- data = b"short binary data"
- with open(os.path.join(path, "part-0000"), 'wb') as f:
- f.write(data)
- [(p, d)] = self.sc.binaryFiles(path).collect()
- self.assertTrue(p.endswith("part-0000"))
- self.assertEqual(d, data)
-
- def test_binary_records(self):
- path = os.path.join(self.tempdir.name, "binaryrecords")
- os.mkdir(path)
- with open(os.path.join(path, "part-0000"), 'w') as f:
- for i in range(100):
- f.write('%04d' % i)
- result = self.sc.binaryRecords(path, 4).map(int).collect()
- self.assertEqual(list(range(100)), result)
-
-
-class OutputFormatTests(ReusedPySparkTestCase):
-
- def setUp(self):
- self.tempdir = tempfile.NamedTemporaryFile(delete=False)
- os.unlink(self.tempdir.name)
-
- def tearDown(self):
- shutil.rmtree(self.tempdir.name, ignore_errors=True)
-
- @unittest.skipIf(sys.version >= "3", "serialize array of byte")
- def test_sequencefiles(self):
- basepath = self.tempdir.name
- ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
- self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/")
- ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect())
- self.assertEqual(ints, ei)
-
- ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')]
- self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/")
- doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect())
- self.assertEqual(doubles, ed)
-
- ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))]
- self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/")
- bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect())
- self.assertEqual(bytes, ebs)
-
- et = [(u'1', u'aa'),
- (u'2', u'bb'),
- (u'3', u'cc')]
- self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/")
- text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect())
- self.assertEqual(text, et)
-
- eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)]
- self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/")
- bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect())
- self.assertEqual(bools, eb)
-
- en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
- self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/")
- nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect())
- self.assertEqual(nulls, en)
-
- em = [(1, {}),
- (1, {3.0: u'bb'}),
- (2, {1.0: u'aa'}),
- (2, {1.0: u'cc'}),
- (3, {2.0: u'dd'})]
- self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/")
- maps = self.sc.sequenceFile(basepath + "/sfmap/").collect()
- for v in maps:
- self.assertTrue(v in em)
-
- def test_oldhadoop(self):
- basepath = self.tempdir.name
- dict_data = [(1, {}),
- (1, {"row1": 1.0}),
- (2, {"row2": 2.0})]
- self.sc.parallelize(dict_data).saveAsHadoopFile(
- basepath + "/oldhadoop/",
- "org.apache.hadoop.mapred.SequenceFileOutputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.MapWritable")
- result = self.sc.hadoopFile(
- basepath + "/oldhadoop/",
- "org.apache.hadoop.mapred.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.MapWritable").collect()
- for v in result:
- self.assertTrue(v in dict_data)
-
- conf = {
- "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat",
- "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
- "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable",
- "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/"
- }
- self.sc.parallelize(dict_data).saveAsHadoopDataset(conf)
- input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"}
- result = self.sc.hadoopRDD(
- "org.apache.hadoop.mapred.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.MapWritable",
- conf=input_conf).collect()
- for v in result:
- self.assertTrue(v in dict_data)
-
- def test_newhadoop(self):
- basepath = self.tempdir.name
- data = [(1, ""),
- (1, "a"),
- (2, "bcdf")]
- self.sc.parallelize(data).saveAsNewAPIHadoopFile(
- basepath + "/newhadoop/",
- "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text")
- result = sorted(self.sc.newAPIHadoopFile(
- basepath + "/newhadoop/",
- "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text").collect())
- self.assertEqual(result, data)
-
- conf = {
- "mapreduce.job.outputformat.class":
- "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
- "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
- "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text",
- "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/"
- }
- self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf)
- input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"}
- new_dataset = sorted(self.sc.newAPIHadoopRDD(
- "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text",
- conf=input_conf).collect())
- self.assertEqual(new_dataset, data)
-
- @unittest.skipIf(sys.version >= "3", "serialize of array")
- def test_newhadoop_with_array(self):
- basepath = self.tempdir.name
- # use custom ArrayWritable types and converters to handle arrays
- array_data = [(1, array('d')),
- (1, array('d', [1.0, 2.0, 3.0])),
- (2, array('d', [3.0, 4.0, 5.0]))]
- self.sc.parallelize(array_data).saveAsNewAPIHadoopFile(
- basepath + "/newhadoop/",
- "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.spark.api.python.DoubleArrayWritable",
- valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter")
- result = sorted(self.sc.newAPIHadoopFile(
- basepath + "/newhadoop/",
- "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.spark.api.python.DoubleArrayWritable",
- valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect())
- self.assertEqual(result, array_data)
-
- conf = {
- "mapreduce.job.outputformat.class":
- "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
- "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
- "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable",
- "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/"
- }
- self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset(
- conf,
- valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter")
- input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"}
- new_dataset = sorted(self.sc.newAPIHadoopRDD(
- "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.spark.api.python.DoubleArrayWritable",
- valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter",
- conf=input_conf).collect())
- self.assertEqual(new_dataset, array_data)
-
- def test_newolderror(self):
- basepath = self.tempdir.name
- rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x))
- self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile(
- basepath + "/newolderror/saveAsHadoopFile/",
- "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat"))
- self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile(
- basepath + "/newolderror/saveAsNewAPIHadoopFile/",
- "org.apache.hadoop.mapred.SequenceFileOutputFormat"))
-
- def test_bad_inputs(self):
- basepath = self.tempdir.name
- rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x))
- self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile(
- basepath + "/badinputs/saveAsHadoopFile/",
- "org.apache.hadoop.mapred.NotValidOutputFormat"))
- self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile(
- basepath + "/badinputs/saveAsNewAPIHadoopFile/",
- "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat"))
-
- def test_converters(self):
- # use of custom converters
- basepath = self.tempdir.name
- data = [(1, {3.0: u'bb'}),
- (2, {1.0: u'aa'}),
- (3, {2.0: u'dd'})]
- self.sc.parallelize(data).saveAsNewAPIHadoopFile(
- basepath + "/converters/",
- "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
- keyConverter="org.apache.spark.api.python.TestOutputKeyConverter",
- valueConverter="org.apache.spark.api.python.TestOutputValueConverter")
- converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect())
- expected = [(u'1', 3.0),
- (u'2', 1.0),
- (u'3', 2.0)]
- self.assertEqual(converted, expected)
-
- def test_reserialization(self):
- basepath = self.tempdir.name
- x = range(1, 5)
- y = range(1001, 1005)
- data = list(zip(x, y))
- rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y))
- rdd.saveAsSequenceFile(basepath + "/reserialize/sequence")
- result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect())
- self.assertEqual(result1, data)
-
- rdd.saveAsHadoopFile(
- basepath + "/reserialize/hadoop",
- "org.apache.hadoop.mapred.SequenceFileOutputFormat")
- result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect())
- self.assertEqual(result2, data)
-
- rdd.saveAsNewAPIHadoopFile(
- basepath + "/reserialize/newhadoop",
- "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")
- result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect())
- self.assertEqual(result3, data)
-
- conf4 = {
- "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat",
- "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
- "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable",
- "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"}
- rdd.saveAsHadoopDataset(conf4)
- result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect())
- self.assertEqual(result4, data)
-
- conf5 = {"mapreduce.job.outputformat.class":
- "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
- "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
- "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable",
- "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset"
- }
- rdd.saveAsNewAPIHadoopDataset(conf5)
- result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect())
- self.assertEqual(result5, data)
-
- def test_malformed_RDD(self):
- basepath = self.tempdir.name
- # non-batch-serialized RDD[[(K, V)]] should be rejected
- data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]]
- rdd = self.sc.parallelize(data, len(data))
- self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile(
- basepath + "/malformed/sequence"))
-
-
-class DaemonTests(unittest.TestCase):
- def connect(self, port):
- from socket import socket, AF_INET, SOCK_STREAM
- sock = socket(AF_INET, SOCK_STREAM)
- sock.connect(('127.0.0.1', port))
- # send a split index of -1 to shut down the worker
- sock.send(b"\xFF\xFF\xFF\xFF")
- sock.close()
- return True
-
- def do_termination_test(self, terminator):
- from subprocess import Popen, PIPE
- from errno import ECONNREFUSED
-
- # start daemon
- daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
- python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON")
- daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE)
-
- # read the port number
- port = read_int(daemon.stdout)
-
- # daemon should accept connections
- self.assertTrue(self.connect(port))
-
- # request shutdown
- terminator(daemon)
- time.sleep(1)
-
- # daemon should no longer accept connections
- try:
- self.connect(port)
- except EnvironmentError as exception:
- self.assertEqual(exception.errno, ECONNREFUSED)
- else:
- self.fail("Expected EnvironmentError to be raised")
-
- def test_termination_stdin(self):
- """Ensure that daemon and workers terminate when stdin is closed."""
- self.do_termination_test(lambda daemon: daemon.stdin.close())
-
- def test_termination_sigterm(self):
- """Ensure that daemon and workers terminate on SIGTERM."""
- from signal import SIGTERM
- self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
-
-
-class WorkerTests(ReusedPySparkTestCase):
- def test_cancel_task(self):
- temp = tempfile.NamedTemporaryFile(delete=True)
- temp.close()
- path = temp.name
-
- def sleep(x):
- import os
- import time
- with open(path, 'w') as f:
- f.write("%d %d" % (os.getppid(), os.getpid()))
- time.sleep(100)
-
- # start job in background thread
- def run():
- try:
- self.sc.parallelize(range(1), 1).foreach(sleep)
- except Exception:
- pass
- import threading
- t = threading.Thread(target=run)
- t.daemon = True
- t.start()
-
- daemon_pid, worker_pid = 0, 0
- while True:
- if os.path.exists(path):
- with open(path) as f:
- data = f.read().split(' ')
- daemon_pid, worker_pid = map(int, data)
- break
- time.sleep(0.1)
-
- # cancel jobs
- self.sc.cancelAllJobs()
- t.join()
-
- for i in range(50):
- try:
- os.kill(worker_pid, 0)
- time.sleep(0.1)
- except OSError:
- break # worker was killed
- else:
- self.fail("worker has not been killed after 5 seconds")
-
- try:
- os.kill(daemon_pid, 0)
- except OSError:
- self.fail("daemon had been killed")
-
- # run a normal job
- rdd = self.sc.parallelize(xrange(100), 1)
- self.assertEqual(100, rdd.map(str).count())
-
- def test_after_exception(self):
- def raise_exception(_):
- raise Exception()
- rdd = self.sc.parallelize(xrange(100), 1)
- with QuietTest(self.sc):
- self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
- self.assertEqual(100, rdd.map(str).count())
-
- def test_after_jvm_exception(self):
- tempFile = tempfile.NamedTemporaryFile(delete=False)
- tempFile.write(b"Hello World!")
- tempFile.close()
- data = self.sc.textFile(tempFile.name, 1)
- filtered_data = data.filter(lambda x: True)
- self.assertEqual(1, filtered_data.count())
- os.unlink(tempFile.name)
- with QuietTest(self.sc):
- self.assertRaises(Exception, lambda: filtered_data.count())
-
- rdd = self.sc.parallelize(xrange(100), 1)
- self.assertEqual(100, rdd.map(str).count())
-
- def test_accumulator_when_reuse_worker(self):
- from pyspark.accumulators import INT_ACCUMULATOR_PARAM
- acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
- self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
- self.assertEqual(sum(range(100)), acc1.value)
-
- acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
- self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
- self.assertEqual(sum(range(100)), acc2.value)
- self.assertEqual(sum(range(100)), acc1.value)
-
- def test_reuse_worker_after_take(self):
- rdd = self.sc.parallelize(xrange(100000), 1)
- self.assertEqual(0, rdd.first())
-
- def count():
- try:
- rdd.count()
- except Exception:
- pass
-
- t = threading.Thread(target=count)
- t.daemon = True
- t.start()
- t.join(5)
- self.assertTrue(not t.isAlive())
- self.assertEqual(100000, rdd.count())
-
- def test_with_different_versions_of_python(self):
- rdd = self.sc.parallelize(range(10))
- rdd.count()
- version = self.sc.pythonVer
- self.sc.pythonVer = "2.0"
- try:
- with QuietTest(self.sc):
- self.assertRaises(Py4JJavaError, lambda: rdd.count())
- finally:
- self.sc.pythonVer = version
-
-
-class SparkSubmitTests(unittest.TestCase):
-
- def setUp(self):
- self.programDir = tempfile.mkdtemp()
- tmp_dir = tempfile.gettempdir()
- self.sparkSubmit = [
- os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"),
- "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
- "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
- ]
-
- def tearDown(self):
- shutil.rmtree(self.programDir)
-
- def createTempFile(self, name, content, dir=None):
- """
- Create a temp file with the given name and content and return its path.
- Strips leading spaces from content up to the first '|' in each line.
- """
- pattern = re.compile(r'^ *\|', re.MULTILINE)
- content = re.sub(pattern, '', content.strip())
- if dir is None:
- path = os.path.join(self.programDir, name)
- else:
- os.makedirs(os.path.join(self.programDir, dir))
- path = os.path.join(self.programDir, dir, name)
- with open(path, "w") as f:
- f.write(content)
- return path
-
- def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None):
- """
- Create a zip archive containing a file with the given content and return its path.
- Strips leading spaces from content up to the first '|' in each line.
- """
- pattern = re.compile(r'^ *\|', re.MULTILINE)
- content = re.sub(pattern, '', content.strip())
- if dir is None:
- path = os.path.join(self.programDir, name + ext)
- else:
- path = os.path.join(self.programDir, dir, zip_name + ext)
- zip = zipfile.ZipFile(path, 'w')
- zip.writestr(name, content)
- zip.close()
- return path
-
- def create_spark_package(self, artifact_name):
- group_id, artifact_id, version = artifact_name.split(":")
- self.createTempFile("%s-%s.pom" % (artifact_id, version), ("""
- |<?xml version="1.0" encoding="UTF-8"?>
- |<project xmlns="http://maven.apache.org/POM/4.0.0"
- |    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
- |    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
- |  <modelVersion>4.0.0</modelVersion>
- |  <groupId>%s</groupId>
- |  <artifactId>%s</artifactId>
- |  <version>%s</version>
- |</project>
- """ % (group_id, artifact_id, version)).lstrip(),
- os.path.join(group_id, artifact_id, version))
- self.createFileInZip("%s.py" % artifact_id, """
- |def myfunc(x):
- | return x + 1
- """, ".jar", os.path.join(group_id, artifact_id, version),
- "%s-%s" % (artifact_id, version))
-
- def test_single_script(self):
- """Submit and test a single script file"""
- script = self.createTempFile("test.py", """
- |from pyspark import SparkContext
- |
- |sc = SparkContext()
- |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect())
- """)
- proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
- out, err = proc.communicate()
- self.assertEqual(0, proc.returncode)
- self.assertIn("[2, 4, 6]", out.decode('utf-8'))
-
- def test_script_with_local_functions(self):
- """Submit and test a single script file calling a global function"""
- script = self.createTempFile("test.py", """
- |from pyspark import SparkContext
- |
- |def foo(x):
- | return x * 3
- |
- |sc = SparkContext()
- |print(sc.parallelize([1, 2, 3]).map(foo).collect())
- """)
- proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
- out, err = proc.communicate()
- self.assertEqual(0, proc.returncode)
- self.assertIn("[3, 6, 9]", out.decode('utf-8'))
-
- def test_module_dependency(self):
- """Submit and test a script with a dependency on another module"""
- script = self.createTempFile("test.py", """
- |from pyspark import SparkContext
- |from mylib import myfunc
- |
- |sc = SparkContext()
- |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
- """)
- zip = self.createFileInZip("mylib.py", """
- |def myfunc(x):
- | return x + 1
- """)
- proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script],
- stdout=subprocess.PIPE)
- out, err = proc.communicate()
- self.assertEqual(0, proc.returncode)
- self.assertIn("[2, 3, 4]", out.decode('utf-8'))
-
- def test_module_dependency_on_cluster(self):
- """Submit and test a script with a dependency on another module on a cluster"""
- script = self.createTempFile("test.py", """
- |from pyspark import SparkContext
- |from mylib import myfunc
- |
- |sc = SparkContext()
- |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
- """)
- zip = self.createFileInZip("mylib.py", """
- |def myfunc(x):
- | return x + 1
- """)
- proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master",
- "local-cluster[1,1,1024]", script],
- stdout=subprocess.PIPE)
- out, err = proc.communicate()
- self.assertEqual(0, proc.returncode)
- self.assertIn("[2, 3, 4]", out.decode('utf-8'))
-
- def test_package_dependency(self):
- """Submit and test a script with a dependency on a Spark Package"""
- script = self.createTempFile("test.py", """
- |from pyspark import SparkContext
- |from mylib import myfunc
- |
- |sc = SparkContext()
- |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
- """)
- self.create_spark_package("a:mylib:0.1")
- proc = subprocess.Popen(
- self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
- "file:" + self.programDir, script],
- stdout=subprocess.PIPE)
- out, err = proc.communicate()
- self.assertEqual(0, proc.returncode)
- self.assertIn("[2, 3, 4]", out.decode('utf-8'))
-
- def test_package_dependency_on_cluster(self):
- """Submit and test a script with a dependency on a Spark Package on a cluster"""
- script = self.createTempFile("test.py", """
- |from pyspark import SparkContext
- |from mylib import myfunc
- |
- |sc = SparkContext()
- |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
- """)
- self.create_spark_package("a:mylib:0.1")
- proc = subprocess.Popen(
- self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
- "file:" + self.programDir, "--master", "local-cluster[1,1,1024]",
- script],
- stdout=subprocess.PIPE)
- out, err = proc.communicate()
- self.assertEqual(0, proc.returncode)
- self.assertIn("[2, 3, 4]", out.decode('utf-8'))
-
- def test_single_script_on_cluster(self):
- """Submit and test a single script on a cluster"""
- script = self.createTempFile("test.py", """
- |from pyspark import SparkContext
- |
- |def foo(x):
- | return x * 2
- |
- |sc = SparkContext()
- |print(sc.parallelize([1, 2, 3]).map(foo).collect())
- """)
- # this will fail if you have different spark.executor.memory
- # in conf/spark-defaults.conf
- proc = subprocess.Popen(
- self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script],
- stdout=subprocess.PIPE)
- out, err = proc.communicate()
- self.assertEqual(0, proc.returncode)
- self.assertIn("[2, 4, 6]", out.decode('utf-8'))
-
- def test_user_configuration(self):
- """Make sure user configuration is respected (SPARK-19307)"""
- script = self.createTempFile("test.py", """
- |from pyspark import SparkConf, SparkContext
- |
- |conf = SparkConf().set("spark.test_config", "1")
- |sc = SparkContext(conf = conf)
- |try:
- | if sc._conf.get("spark.test_config") != "1":
- | raise Exception("Cannot find spark.test_config in SparkContext's conf.")
- |finally:
- | sc.stop()
- """)
- proc = subprocess.Popen(
- self.sparkSubmit + ["--master", "local", script],
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT)
- out, err = proc.communicate()
- self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out))
-
- def test_conda(self):
- """Submit and test a single script file via conda"""
- script = self.createTempFile("test.py", """
- |from pyspark import SparkContext
- |
- |sc = SparkContext()
- |sc.addCondaPackages('numpy=1.14.0')
- |
- |# Ensure numpy is accessible on the driver
- |import numpy
- |arr = [1, 2, 3]
- |def mul2(x):
- | # Also ensure numpy accessible from executor
- | assert numpy.version.version == "1.14.0"
- | return x * 2
- |print(sc.parallelize(arr).map(mul2).collect())
- """)
- props = self.createTempFile("properties", """
- |spark.conda.binaryPath {}
- |spark.conda.channelUrls https://repo.continuum.io/pkgs/main
- |spark.conda.bootstrapPackages python=3.5
- """.format(os.environ["CONDA_BIN"]))
- env = dict(os.environ)
- del env['PYSPARK_PYTHON']
- del env['PYSPARK_DRIVER_PYTHON']
- proc = subprocess.Popen(self.sparkSubmit + [
- "--properties-file", props, script],
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- env=env)
- out, err = proc.communicate()
- if 0 != proc.returncode:
- self.fail(("spark-submit was unsuccessful with error code {}\n\n" +
- "stdout:\n{}\n\nstderr:\n{}").format(proc.returncode, out, err))
- self.assertIn("[2, 4, 6]", out.decode('utf-8'))
-
-
-class ContextTests(unittest.TestCase):
-
- def test_failed_sparkcontext_creation(self):
- # Regression test for SPARK-1550
- self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
-
- def test_get_or_create(self):
- with SparkContext.getOrCreate() as sc:
- self.assertTrue(SparkContext.getOrCreate() is sc)
-
- def test_parallelize_eager_cleanup(self):
- with SparkContext() as sc:
- temp_files = os.listdir(sc._temp_dir)
- rdd = sc.parallelize([0, 1, 2])
- post_parallalize_temp_files = os.listdir(sc._temp_dir)
- self.assertEqual(temp_files, post_parallalize_temp_files)
-
- def test_set_conf(self):
- # This is for an internal use case. When there is an existing SparkContext,
- # SparkSession's builder needs to set configs into SparkContext's conf.
- sc = SparkContext()
- sc._conf.set("spark.test.SPARK16224", "SPARK16224")
- self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224")
- sc.stop()
-
- def test_stop(self):
- sc = SparkContext()
- self.assertNotEqual(SparkContext._active_spark_context, None)
- sc.stop()
- self.assertEqual(SparkContext._active_spark_context, None)
-
- def test_with(self):
- with SparkContext() as sc:
- self.assertNotEqual(SparkContext._active_spark_context, None)
- self.assertEqual(SparkContext._active_spark_context, None)
-
- def test_with_exception(self):
- try:
- with SparkContext() as sc:
- self.assertNotEqual(SparkContext._active_spark_context, None)
- raise Exception()
- except:
- pass
- self.assertEqual(SparkContext._active_spark_context, None)
-
- def test_with_stop(self):
- with SparkContext() as sc:
- self.assertNotEqual(SparkContext._active_spark_context, None)
- sc.stop()
- self.assertEqual(SparkContext._active_spark_context, None)
-
- def test_progress_api(self):
- with SparkContext() as sc:
- sc.setJobGroup('test_progress_api', '', True)
- rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))
-
- def run():
- try:
- rdd.count()
- except Exception:
- pass
- t = threading.Thread(target=run)
- t.daemon = True
- t.start()
- # wait for scheduler to start
- time.sleep(1)
-
- tracker = sc.statusTracker()
- jobIds = tracker.getJobIdsForGroup('test_progress_api')
- self.assertEqual(1, len(jobIds))
- job = tracker.getJobInfo(jobIds[0])
- self.assertEqual(1, len(job.stageIds))
- stage = tracker.getStageInfo(job.stageIds[0])
- self.assertEqual(rdd.getNumPartitions(), stage.numTasks)
-
- sc.cancelAllJobs()
- t.join()
- # wait for event listener to update the status
- time.sleep(1)
-
- job = tracker.getJobInfo(jobIds[0])
- self.assertEqual('FAILED', job.status)
- self.assertEqual([], tracker.getActiveJobsIds())
- self.assertEqual([], tracker.getActiveStageIds())
-
- sc.stop()
-
- def test_startTime(self):
- with SparkContext() as sc:
- self.assertGreater(sc.startTime, 0)
-
-
-class ConfTests(unittest.TestCase):
- def test_memory_conf(self):
- memoryList = ["1T", "1G", "1M", "1024K"]
- for memory in memoryList:
- sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory))
- l = list(range(1024))
- random.shuffle(l)
- rdd = sc.parallelize(l, 4)
- self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
- sc.stop()
-
-
-class KeywordOnlyTests(unittest.TestCase):
- class Wrapped(object):
- @keyword_only
- def set(self, x=None, y=None):
- if "x" in self._input_kwargs:
- self._x = self._input_kwargs["x"]
- if "y" in self._input_kwargs:
- self._y = self._input_kwargs["y"]
- return x, y
-
- def test_keywords(self):
- w = self.Wrapped()
- x, y = w.set(y=1)
- self.assertEqual(y, 1)
- self.assertEqual(y, w._y)
- self.assertIsNone(x)
- self.assertFalse(hasattr(w, "_x"))
-
- def test_non_keywords(self):
- w = self.Wrapped()
- self.assertRaises(TypeError, lambda: w.set(0, y=1))
-
- def test_kwarg_ownership(self):
- # test _input_kwargs is owned by each class instance and not a shared static variable
- class Setter(object):
- @keyword_only
- def set(self, x=None, other=None, other_x=None):
- if "other" in self._input_kwargs:
- self._input_kwargs["other"].set(x=self._input_kwargs["other_x"])
- self._x = self._input_kwargs["x"]
-
- a = Setter()
- b = Setter()
- a.set(x=1, other=b, other_x=2)
- self.assertEqual(a._x, 1)
- self.assertEqual(b._x, 2)
-
-
-class UtilTests(PySparkTestCase):
- def test_py4j_exception_message(self):
- from pyspark.util import _exception_message
-
- with self.assertRaises(Py4JJavaError) as context:
- # This attempts java.lang.String(null) which throws an NPE.
- self.sc._jvm.java.lang.String(None)
-
- self.assertTrue('NullPointerException' in _exception_message(context.exception))
-
- def test_parsing_version_string(self):
- from pyspark.util import VersionUtils
- self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced"))
-
-
-@unittest.skipIf(not _have_scipy, "SciPy not installed")
-class SciPyTests(PySparkTestCase):
-
- """General PySpark tests that depend on scipy """
-
- def test_serialize(self):
- from scipy.special import gammaln
- x = range(1, 5)
- expected = list(map(gammaln, x))
- observed = self.sc.parallelize(x).map(gammaln).collect()
- self.assertEqual(expected, observed)
-
-
-@unittest.skipIf(not _have_numpy, "NumPy not installed")
-class NumPyTests(PySparkTestCase):
-
- """General PySpark tests that depend on numpy """
-
- def test_statcounter_array(self):
- x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
- s = x.stats()
- self.assertSequenceEqual([2.0, 2.0], s.mean().tolist())
- self.assertSequenceEqual([1.0, 1.0], s.min().tolist())
- self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
- self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())
-
- stats_dict = s.asDict()
- self.assertEqual(3, stats_dict['count'])
- self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist())
- self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist())
- self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist())
- self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist())
- self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist())
- self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist())
-
- stats_sample_dict = s.asDict(sample=True)
- self.assertEqual(3, stats_dict['count'])
- self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist())
- self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist())
- self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist())
- self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist())
- self.assertSequenceEqual(
- [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist())
- self.assertSequenceEqual(
- [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist())
-
-
-if __name__ == "__main__":
- from pyspark.tests import *
- runner = unishark.BufferedTestRunner(
- reporters=[unishark.XUnitReporter('target/test-reports/pyspark_{}'.format(
- os.path.basename(os.environ.get("PYSPARK_PYTHON", ""))))])
- unittest.main(testRunner=runner, verbosity=2)
diff --git a/python/pyspark/tests/__init__.py b/python/pyspark/tests/__init__.py
new file mode 100644
index 0000000000000..12bdf0d0175b6
--- /dev/null
+++ b/python/pyspark/tests/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py
new file mode 100644
index 0000000000000..92bcb11561307
--- /dev/null
+++ b/python/pyspark/tests/test_appsubmit.py
@@ -0,0 +1,248 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import re
+import shutil
+import subprocess
+import tempfile
+import unittest
+import zipfile
+
+
+class SparkSubmitTests(unittest.TestCase):
+
+ def setUp(self):
+ self.programDir = tempfile.mkdtemp()
+ tmp_dir = tempfile.gettempdir()
+ self.sparkSubmit = [
+ os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit"),
+ "--conf", "spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
+ "--conf", "spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
+ ]
+
+ def tearDown(self):
+ shutil.rmtree(self.programDir)
+
+ def createTempFile(self, name, content, dir=None):
+ """
+ Create a temp file with the given name and content and return its path.
+ Strips leading spaces from content up to the first '|' in each line.
+ """
+ pattern = re.compile(r'^ *\|', re.MULTILINE)
+ content = re.sub(pattern, '', content.strip())
+ if dir is None:
+ path = os.path.join(self.programDir, name)
+ else:
+ os.makedirs(os.path.join(self.programDir, dir))
+ path = os.path.join(self.programDir, dir, name)
+ with open(path, "w") as f:
+ f.write(content)
+ return path
+
+ def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None):
+ """
+ Create a zip archive containing a file with the given content and return its path.
+ Strips leading spaces from content up to the first '|' in each line.
+ """
+ pattern = re.compile(r'^ *\|', re.MULTILINE)
+ content = re.sub(pattern, '', content.strip())
+ if dir is None:
+ path = os.path.join(self.programDir, name + ext)
+ else:
+ path = os.path.join(self.programDir, dir, zip_name + ext)
+ zip = zipfile.ZipFile(path, 'w')
+ zip.writestr(name, content)
+ zip.close()
+ return path
+
+ def create_spark_package(self, artifact_name):
+ group_id, artifact_id, version = artifact_name.split(":")
+ self.createTempFile("%s-%s.pom" % (artifact_id, version), ("""
+            |<?xml version="1.0" encoding="UTF-8"?>
+            |<project xmlns="http://maven.apache.org/POM/4.0.0"
+            |       xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+            |       xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
+            |       http://maven.apache.org/xsd/maven-4.0.0.xsd">
+            |   <modelVersion>4.0.0</modelVersion>
+            |   <groupId>%s</groupId>
+            |   <artifactId>%s</artifactId>
+            |   <version>%s</version>
+            |</project>
+ """ % (group_id, artifact_id, version)).lstrip(),
+ os.path.join(group_id, artifact_id, version))
+ self.createFileInZip("%s.py" % artifact_id, """
+ |def myfunc(x):
+ | return x + 1
+ """, ".jar", os.path.join(group_id, artifact_id, version),
+ "%s-%s" % (artifact_id, version))
+
+ def test_single_script(self):
+ """Submit and test a single script file"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkContext
+ |
+ |sc = SparkContext()
+ |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect())
+ """)
+ proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode)
+ self.assertIn("[2, 4, 6]", out.decode('utf-8'))
+
+ def test_script_with_local_functions(self):
+ """Submit and test a single script file calling a global function"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkContext
+ |
+ |def foo(x):
+ | return x * 3
+ |
+ |sc = SparkContext()
+ |print(sc.parallelize([1, 2, 3]).map(foo).collect())
+ """)
+ proc = subprocess.Popen(self.sparkSubmit + [script], stdout=subprocess.PIPE)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode)
+ self.assertIn("[3, 6, 9]", out.decode('utf-8'))
+
+ def test_module_dependency(self):
+ """Submit and test a script with a dependency on another module"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkContext
+ |from mylib import myfunc
+ |
+ |sc = SparkContext()
+ |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
+ """)
+ zip = self.createFileInZip("mylib.py", """
+ |def myfunc(x):
+ | return x + 1
+ """)
+ proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, script],
+ stdout=subprocess.PIPE)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode)
+ self.assertIn("[2, 3, 4]", out.decode('utf-8'))
+
+ def test_module_dependency_on_cluster(self):
+ """Submit and test a script with a dependency on another module on a cluster"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkContext
+ |from mylib import myfunc
+ |
+ |sc = SparkContext()
+ |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
+ """)
+ zip = self.createFileInZip("mylib.py", """
+ |def myfunc(x):
+ | return x + 1
+ """)
+ proc = subprocess.Popen(self.sparkSubmit + ["--py-files", zip, "--master",
+ "local-cluster[1,1,1024]", script],
+ stdout=subprocess.PIPE)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode)
+ self.assertIn("[2, 3, 4]", out.decode('utf-8'))
+
+ def test_package_dependency(self):
+ """Submit and test a script with a dependency on a Spark Package"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkContext
+ |from mylib import myfunc
+ |
+ |sc = SparkContext()
+ |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
+ """)
+ self.create_spark_package("a:mylib:0.1")
+ proc = subprocess.Popen(
+ self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
+ "file:" + self.programDir, script],
+ stdout=subprocess.PIPE)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode)
+ self.assertIn("[2, 3, 4]", out.decode('utf-8'))
+
+ def test_package_dependency_on_cluster(self):
+ """Submit and test a script with a dependency on a Spark Package on a cluster"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkContext
+ |from mylib import myfunc
+ |
+ |sc = SparkContext()
+ |print(sc.parallelize([1, 2, 3]).map(myfunc).collect())
+ """)
+ self.create_spark_package("a:mylib:0.1")
+ proc = subprocess.Popen(
+ self.sparkSubmit + ["--packages", "a:mylib:0.1", "--repositories",
+ "file:" + self.programDir, "--master", "local-cluster[1,1,1024]",
+ script],
+ stdout=subprocess.PIPE)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode)
+ self.assertIn("[2, 3, 4]", out.decode('utf-8'))
+
+ def test_single_script_on_cluster(self):
+ """Submit and test a single script on a cluster"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkContext
+ |
+ |def foo(x):
+ | return x * 2
+ |
+ |sc = SparkContext()
+ |print(sc.parallelize([1, 2, 3]).map(foo).collect())
+ """)
+ # this will fail if you have different spark.executor.memory
+ # in conf/spark-defaults.conf
+ proc = subprocess.Popen(
+ self.sparkSubmit + ["--master", "local-cluster[1,1,1024]", script],
+ stdout=subprocess.PIPE)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode)
+ self.assertIn("[2, 4, 6]", out.decode('utf-8'))
+
+ def test_user_configuration(self):
+ """Make sure user configuration is respected (SPARK-19307)"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkConf, SparkContext
+ |
+ |conf = SparkConf().set("spark.test_config", "1")
+ |sc = SparkContext(conf = conf)
+ |try:
+ | if sc._conf.get("spark.test_config") != "1":
+ | raise Exception("Cannot find spark.test_config in SparkContext's conf.")
+ |finally:
+ | sc.stop()
+ """)
+ proc = subprocess.Popen(
+ self.sparkSubmit + ["--master", "local", script],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode, msg="Process failed with error:\n {0}".format(out))
+
+
+if __name__ == "__main__":
+ from pyspark.tests.test_appsubmit import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
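
Note: the createTempFile and createFileInZip helpers in test_appsubmit.py strip a leading-pipe margin so the submitted program sources can stay indented inside the triple-quoted strings. A minimal standalone sketch of that behaviour (the sample snippet is illustrative, not part of the patch):

    import re

    # Same pattern as the helpers: drop everything up to and including the
    # first '|' on each line, so indented test sources come out flush-left.
    MARGIN = re.compile(r'^ *\|', re.MULTILINE)

    snippet = """
        |def myfunc(x):
        |    return x + 1
        """
    print(re.sub(MARGIN, '', snippet.strip()))
    # def myfunc(x):
    #     return x + 1
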
diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/tests/test_broadcast.py
similarity index 91%
rename from python/pyspark/test_broadcast.py
rename to python/pyspark/tests/test_broadcast.py
index a00329c18ad8f..a98626e8f4bc9 100644
--- a/python/pyspark/test_broadcast.py
+++ b/python/pyspark/tests/test_broadcast.py
@@ -14,20 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
import os
import random
import tempfile
import unittest
-try:
- import xmlrunner
-except ImportError:
- xmlrunner = None
-
-from pyspark.broadcast import Broadcast
-from pyspark.conf import SparkConf
-from pyspark.context import SparkContext
+from pyspark import SparkConf, SparkContext
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import ChunkedStream
@@ -118,9 +110,13 @@ def random_bytes(n):
for buffer_length in [1, 2, 5, 8192]:
self._test_chunked_stream(random_bytes(data_length), buffer_length)
+
if __name__ == '__main__':
- from pyspark.test_broadcast import *
- if xmlrunner:
- unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
- else:
- unittest.main(verbosity=2)
+ from pyspark.tests.test_broadcast import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/tests/test_conf.py b/python/pyspark/tests/test_conf.py
new file mode 100644
index 0000000000000..f5a9accc3fe6e
--- /dev/null
+++ b/python/pyspark/tests/test_conf.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import random
+import unittest
+
+from pyspark import SparkContext, SparkConf
+
+
+class ConfTests(unittest.TestCase):
+ def test_memory_conf(self):
+ memoryList = ["1T", "1G", "1M", "1024K"]
+ for memory in memoryList:
+ sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory))
+ l = list(range(1024))
+ random.shuffle(l)
+ rdd = sc.parallelize(l, 4)
+ self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
+ sc.stop()
+
+
+if __name__ == "__main__":
+ from pyspark.tests.test_conf import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
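
Each of the new modules ends with the same runner stanza: try xmlrunner for XML reports, otherwise fall back to the default runner. Because the tests are now split into importable modules, a single suite can also be driven programmatically; a small sketch, assuming pyspark and SPARK_HOME are already on the environment:

    import unittest

    # Load only the ConfTests module from the new pyspark.tests package and
    # run it with the plain text runner (no xmlrunner needed).
    suite = unittest.defaultTestLoader.loadTestsFromName("pyspark.tests.test_conf")
    unittest.TextTestRunner(verbosity=2).run(suite)
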
diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py
new file mode 100644
index 0000000000000..201baf420354d
--- /dev/null
+++ b/python/pyspark/tests/test_context.py
@@ -0,0 +1,258 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import shutil
+import tempfile
+import threading
+import time
+import unittest
+
+from pyspark import SparkFiles, SparkContext
+from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME
+
+
+class CheckpointTests(ReusedPySparkTestCase):
+
+ def setUp(self):
+ self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(self.checkpointDir.name)
+ self.sc.setCheckpointDir(self.checkpointDir.name)
+
+ def tearDown(self):
+ shutil.rmtree(self.checkpointDir.name)
+
+ def test_basic_checkpointing(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertTrue(flatMappedRDD.getCheckpointFile() is None)
+
+ flatMappedRDD.checkpoint()
+ result = flatMappedRDD.collect()
+ time.sleep(1) # 1 second
+ self.assertTrue(flatMappedRDD.isCheckpointed())
+ self.assertEqual(flatMappedRDD.collect(), result)
+ self.assertEqual("file:" + self.checkpointDir.name,
+ os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile())))
+
+ def test_checkpoint_and_restore(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertTrue(flatMappedRDD.getCheckpointFile() is None)
+
+ flatMappedRDD.checkpoint()
+ flatMappedRDD.count() # forces a checkpoint to be computed
+ time.sleep(1) # 1 second
+
+ self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+ flatMappedRDD._jrdd_deserializer)
+ self.assertEqual([1, 2, 3, 4], recovered.collect())
+
+
+class LocalCheckpointTests(ReusedPySparkTestCase):
+
+ def test_basic_localcheckpointing(self):
+ parCollection = self.sc.parallelize([1, 2, 3, 4])
+ flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+ self.assertFalse(flatMappedRDD.isCheckpointed())
+ self.assertFalse(flatMappedRDD.isLocallyCheckpointed())
+
+ flatMappedRDD.localCheckpoint()
+ result = flatMappedRDD.collect()
+ time.sleep(1) # 1 second
+ self.assertTrue(flatMappedRDD.isCheckpointed())
+ self.assertTrue(flatMappedRDD.isLocallyCheckpointed())
+ self.assertEqual(flatMappedRDD.collect(), result)
+
+
+class AddFileTests(PySparkTestCase):
+
+ def test_add_py_file(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this job fails due to `userlibrary` not being on the Python path:
+ # disable logging in log4j temporarily
+ def func(x):
+ from userlibrary import UserClass
+ return UserClass().hello()
+ with QuietTest(self.sc):
+ self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first)
+
+ # Add the file, so the job should now succeed:
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addPyFile(path)
+ res = self.sc.parallelize(range(2)).map(func).first()
+ self.assertEqual("Hello World!", res)
+
+ def test_add_file_locally(self):
+ path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
+ self.sc.addFile(path)
+ download_path = SparkFiles.get("hello.txt")
+ self.assertNotEqual(path, download_path)
+ with open(download_path) as test_file:
+ self.assertEqual("Hello World!\n", test_file.readline())
+
+ def test_add_file_recursively_locally(self):
+ path = os.path.join(SPARK_HOME, "python/test_support/hello")
+ self.sc.addFile(path, True)
+ download_path = SparkFiles.get("hello")
+ self.assertNotEqual(path, download_path)
+ with open(download_path + "/hello.txt") as test_file:
+ self.assertEqual("Hello World!\n", test_file.readline())
+ with open(download_path + "/sub_hello/sub_hello.txt") as test_file:
+ self.assertEqual("Sub Hello World!\n", test_file.readline())
+
+ def test_add_py_file_locally(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this fails due to `userlibrary` not being on the Python path:
+ def func():
+ from userlibrary import UserClass
+ self.assertRaises(ImportError, func)
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addPyFile(path)
+ from userlibrary import UserClass
+ self.assertEqual("Hello World!", UserClass().hello())
+
+ def test_add_egg_file_locally(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this fails due to `userlibrary` not being on the Python path:
+ def func():
+ from userlib import UserClass
+ self.assertRaises(ImportError, func)
+ path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip")
+ self.sc.addPyFile(path)
+ from userlib import UserClass
+ self.assertEqual("Hello World from inside a package!", UserClass().hello())
+
+ def test_overwrite_system_module(self):
+ self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py"))
+
+ import SimpleHTTPServer
+ self.assertEqual("My Server", SimpleHTTPServer.__name__)
+
+ def func(x):
+ import SimpleHTTPServer
+ return SimpleHTTPServer.__name__
+
+ self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect())
+
+
+class ContextTests(unittest.TestCase):
+
+ def test_failed_sparkcontext_creation(self):
+ # Regression test for SPARK-1550
+ self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
+
+ def test_get_or_create(self):
+ with SparkContext.getOrCreate() as sc:
+ self.assertTrue(SparkContext.getOrCreate() is sc)
+
+ def test_parallelize_eager_cleanup(self):
+ with SparkContext() as sc:
+ temp_files = os.listdir(sc._temp_dir)
+ rdd = sc.parallelize([0, 1, 2])
+ post_parallalize_temp_files = os.listdir(sc._temp_dir)
+ self.assertEqual(temp_files, post_parallalize_temp_files)
+
+ def test_set_conf(self):
+ # This is for an internal use case. When there is an existing SparkContext,
+ # SparkSession's builder needs to set configs into SparkContext's conf.
+ sc = SparkContext()
+ sc._conf.set("spark.test.SPARK16224", "SPARK16224")
+ self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224")
+ sc.stop()
+
+ def test_stop(self):
+ sc = SparkContext()
+ self.assertNotEqual(SparkContext._active_spark_context, None)
+ sc.stop()
+ self.assertEqual(SparkContext._active_spark_context, None)
+
+ def test_with(self):
+ with SparkContext() as sc:
+ self.assertNotEqual(SparkContext._active_spark_context, None)
+ self.assertEqual(SparkContext._active_spark_context, None)
+
+ def test_with_exception(self):
+ try:
+ with SparkContext() as sc:
+ self.assertNotEqual(SparkContext._active_spark_context, None)
+ raise Exception()
+ except:
+ pass
+ self.assertEqual(SparkContext._active_spark_context, None)
+
+ def test_with_stop(self):
+ with SparkContext() as sc:
+ self.assertNotEqual(SparkContext._active_spark_context, None)
+ sc.stop()
+ self.assertEqual(SparkContext._active_spark_context, None)
+
+ def test_progress_api(self):
+ with SparkContext() as sc:
+ sc.setJobGroup('test_progress_api', '', True)
+ rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))
+
+ def run():
+ try:
+ rdd.count()
+ except Exception:
+ pass
+ t = threading.Thread(target=run)
+ t.daemon = True
+ t.start()
+ # wait for scheduler to start
+ time.sleep(1)
+
+ tracker = sc.statusTracker()
+ jobIds = tracker.getJobIdsForGroup('test_progress_api')
+ self.assertEqual(1, len(jobIds))
+ job = tracker.getJobInfo(jobIds[0])
+ self.assertEqual(1, len(job.stageIds))
+ stage = tracker.getStageInfo(job.stageIds[0])
+ self.assertEqual(rdd.getNumPartitions(), stage.numTasks)
+
+ sc.cancelAllJobs()
+ t.join()
+ # wait for event listener to update the status
+ time.sleep(1)
+
+ job = tracker.getJobInfo(jobIds[0])
+ self.assertEqual('FAILED', job.status)
+ self.assertEqual([], tracker.getActiveJobsIds())
+ self.assertEqual([], tracker.getActiveStageIds())
+
+ sc.stop()
+
+ def test_startTime(self):
+ with SparkContext() as sc:
+ self.assertGreater(sc.startTime, 0)
+
+
+if __name__ == "__main__":
+ from pyspark.tests.test_context import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/tests/test_daemon.py b/python/pyspark/tests/test_daemon.py
new file mode 100644
index 0000000000000..fccd74fff1516
--- /dev/null
+++ b/python/pyspark/tests/test_daemon.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import sys
+import time
+import unittest
+
+from pyspark.serializers import read_int
+
+
+class DaemonTests(unittest.TestCase):
+ def connect(self, port):
+ from socket import socket, AF_INET, SOCK_STREAM
+ sock = socket(AF_INET, SOCK_STREAM)
+ sock.connect(('127.0.0.1', port))
+ # send a split index of -1 to shutdown the worker
+ sock.send(b"\xFF\xFF\xFF\xFF")
+ sock.close()
+ return True
+
+ def do_termination_test(self, terminator):
+ from subprocess import Popen, PIPE
+ from errno import ECONNREFUSED
+
+ # start daemon
+ daemon_path = os.path.join(os.path.dirname(__file__), "..", "daemon.py")
+ python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON")
+ daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE)
+
+ # read the port number
+ port = read_int(daemon.stdout)
+
+ # daemon should accept connections
+ self.assertTrue(self.connect(port))
+
+ # request shutdown
+ terminator(daemon)
+ time.sleep(1)
+
+ # daemon should no longer accept connections
+ try:
+ self.connect(port)
+ except EnvironmentError as exception:
+ self.assertEqual(exception.errno, ECONNREFUSED)
+ else:
+ self.fail("Expected EnvironmentError to be raised")
+
+ def test_termination_stdin(self):
+ """Ensure that daemon and workers terminate when stdin is closed."""
+ self.do_termination_test(lambda daemon: daemon.stdin.close())
+
+ def test_termination_sigterm(self):
+ """Ensure that daemon and workers terminate on SIGTERM."""
+ from signal import SIGTERM
+ self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+
+
+if __name__ == "__main__":
+ from pyspark.tests.test_daemon import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
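
The shutdown handshake in DaemonTests.connect sends b"\xFF\xFF\xFF\xFF", i.e. a split index of -1 in the same 4-byte big-endian framing that read_int decodes. A quick sketch confirming the encoding, using only the standard struct module:

    import struct

    # -1 packed as a signed 32-bit big-endian integer is exactly the marker
    # the test writes to the daemon socket before closing it.
    assert struct.pack("!i", -1) == b"\xff\xff\xff\xff"
    # read_int-style decoding round-trips back to the sentinel value.
    assert struct.unpack("!i", b"\xff\xff\xff\xff")[0] == -1
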
diff --git a/python/pyspark/tests/test_join.py b/python/pyspark/tests/test_join.py
new file mode 100644
index 0000000000000..e97e695f8b20d
--- /dev/null
+++ b/python/pyspark/tests/test_join.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.testing.utils import ReusedPySparkTestCase
+
+
+class JoinTests(ReusedPySparkTestCase):
+
+ def test_narrow_dependency_in_join(self):
+ rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x))
+ parted = rdd.partitionBy(2)
+ self.assertEqual(2, parted.union(parted).getNumPartitions())
+ self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
+ self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())
+
+ tracker = self.sc.statusTracker()
+
+ self.sc.setJobGroup("test1", "test", True)
+ d = sorted(parted.join(parted).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual((0, (0, 0)), d[0])
+ jobId = tracker.getJobIdsForGroup("test1")[0]
+ self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test2", "test", True)
+ d = sorted(parted.join(rdd).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual((0, (0, 0)), d[0])
+ jobId = tracker.getJobIdsForGroup("test2")[0]
+ self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test3", "test", True)
+ d = sorted(parted.cogroup(parted).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual([[0], [0]], list(map(list, d[0][1])))
+ jobId = tracker.getJobIdsForGroup("test3")[0]
+ self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test4", "test", True)
+ d = sorted(parted.cogroup(rdd).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual([[0], [0]], list(map(list, d[0][1])))
+ jobId = tracker.getJobIdsForGroup("test4")[0]
+ self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.tests.test_join import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/tests/test_profiler.py b/python/pyspark/tests/test_profiler.py
new file mode 100644
index 0000000000000..56cbcff01657c
--- /dev/null
+++ b/python/pyspark/tests/test_profiler.py
@@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import tempfile
+import unittest
+
+from pyspark import SparkConf, SparkContext, BasicProfiler
+from pyspark.testing.utils import PySparkTestCase
+
+if sys.version >= "3":
+ from io import StringIO
+else:
+ from StringIO import StringIO
+
+
+class ProfilerTests(PySparkTestCase):
+
+ def setUp(self):
+ self._old_sys_path = list(sys.path)
+ class_name = self.__class__.__name__
+ conf = SparkConf().set("spark.python.profile", "true")
+ self.sc = SparkContext('local[4]', class_name, conf=conf)
+
+ def test_profiler(self):
+ self.do_computation()
+
+ profilers = self.sc.profiler_collector.profilers
+ self.assertEqual(1, len(profilers))
+ id, profiler, _ = profilers[0]
+ stats = profiler.stats()
+ self.assertTrue(stats is not None)
+ width, stat_list = stats.get_print_list([])
+ func_names = [func_name for fname, n, func_name in stat_list]
+ self.assertTrue("heavy_foo" in func_names)
+
+ old_stdout = sys.stdout
+ sys.stdout = io = StringIO()
+ self.sc.show_profiles()
+ self.assertTrue("heavy_foo" in io.getvalue())
+ sys.stdout = old_stdout
+
+ d = tempfile.gettempdir()
+ self.sc.dump_profiles(d)
+ self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
+
+ def test_custom_profiler(self):
+ class TestCustomProfiler(BasicProfiler):
+ def show(self, id):
+ self.result = "Custom formatting"
+
+ self.sc.profiler_collector.profiler_cls = TestCustomProfiler
+
+ self.do_computation()
+
+ profilers = self.sc.profiler_collector.profilers
+ self.assertEqual(1, len(profilers))
+ _, profiler, _ = profilers[0]
+ self.assertTrue(isinstance(profiler, TestCustomProfiler))
+
+ self.sc.show_profiles()
+ self.assertEqual("Custom formatting", profiler.result)
+
+ def do_computation(self):
+ def heavy_foo(x):
+ for i in range(1 << 18):
+ x = 1
+
+ rdd = self.sc.parallelize(range(100))
+ rdd.foreach(heavy_foo)
+
+
+class ProfilerTests2(unittest.TestCase):
+ def test_profiler_disabled(self):
+ sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false"))
+ try:
+ self.assertRaisesRegexp(
+ RuntimeError,
+ "'spark.python.profile' configuration must be set",
+ lambda: sc.show_profiles())
+ self.assertRaisesRegexp(
+ RuntimeError,
+ "'spark.python.profile' configuration must be set",
+ lambda: sc.dump_profiles("/tmp/abc"))
+ finally:
+ sc.stop()
+
+
+if __name__ == "__main__":
+ from pyspark.tests.test_profiler import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
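
ProfilerTests captures the output of sc.show_profiles() by temporarily swapping sys.stdout for a StringIO. On Python 3 the same capture can be written with contextlib.redirect_stdout; a minimal sketch with cProfile standing in for the Spark profiler (the workload function is illustrative):

    import cProfile
    import io
    import pstats
    from contextlib import redirect_stdout

    def heavy_foo(x):
        # Illustrative stand-in for the test's workload.
        for _ in range(1 << 18):
            x = 1

    profiler = cProfile.Profile()
    profiler.runcall(heavy_foo, 0)

    # Same idea as the manual sys.stdout swap in test_profiler: collect the
    # printed stats into an in-memory buffer and inspect them afterwards.
    buf = io.StringIO()
    with redirect_stdout(buf):
        pstats.Stats(profiler).print_stats()  # Stats binds the redirected stdout here
    assert "heavy_foo" in buf.getvalue()
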
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
new file mode 100644
index 0000000000000..b2a544b8de78a
--- /dev/null
+++ b/python/pyspark/tests/test_rdd.py
@@ -0,0 +1,739 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import hashlib
+import os
+import random
+import sys
+import tempfile
+from glob import glob
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark import shuffle, RDD
+from pyspark.serializers import CloudPickleSerializer, BatchedSerializer, PickleSerializer,\
+ MarshalSerializer, UTF8Deserializer, NoOpSerializer
+from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest
+
+if sys.version_info[0] >= 3:
+ xrange = range
+
+
+class RDDTests(ReusedPySparkTestCase):
+
+ def test_range(self):
+ self.assertEqual(self.sc.range(1, 1).count(), 0)
+ self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
+ self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)
+
+ def test_id(self):
+ rdd = self.sc.parallelize(range(10))
+ id = rdd.id()
+ self.assertEqual(id, rdd.id())
+ rdd2 = rdd.map(str).filter(bool)
+ id2 = rdd2.id()
+ self.assertEqual(id + 1, id2)
+ self.assertEqual(id2, rdd2.id())
+
+ def test_empty_rdd(self):
+ rdd = self.sc.emptyRDD()
+ self.assertTrue(rdd.isEmpty())
+
+ def test_sum(self):
+ self.assertEqual(0, self.sc.emptyRDD().sum())
+ self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
+
+ def test_to_localiterator(self):
+ from time import sleep
+ rdd = self.sc.parallelize([1, 2, 3])
+ it = rdd.toLocalIterator()
+ sleep(5)
+ self.assertEqual([1, 2, 3], sorted(it))
+
+ rdd2 = rdd.repartition(1000)
+ it2 = rdd2.toLocalIterator()
+ sleep(5)
+ self.assertEqual([1, 2, 3], sorted(it2))
+
+ def test_save_as_textfile_with_unicode(self):
+ # Regression test for SPARK-970
+ x = u"\u00A1Hola, mundo!"
+ data = self.sc.parallelize([x])
+ tempFile = tempfile.NamedTemporaryFile(delete=True)
+ tempFile.close()
+ data.saveAsTextFile(tempFile.name)
+ raw_contents = b''.join(open(p, 'rb').read()
+ for p in glob(tempFile.name + "/part-0000*"))
+ self.assertEqual(x, raw_contents.strip().decode("utf-8"))
+
+ def test_save_as_textfile_with_utf8(self):
+ x = u"\u00A1Hola, mundo!"
+ data = self.sc.parallelize([x.encode("utf-8")])
+ tempFile = tempfile.NamedTemporaryFile(delete=True)
+ tempFile.close()
+ data.saveAsTextFile(tempFile.name)
+ raw_contents = b''.join(open(p, 'rb').read()
+ for p in glob(tempFile.name + "/part-0000*"))
+ self.assertEqual(x, raw_contents.strip().decode('utf8'))
+
+ def test_transforming_cartesian_result(self):
+ # Regression test for SPARK-1034
+ rdd1 = self.sc.parallelize([1, 2])
+ rdd2 = self.sc.parallelize([3, 4])
+ cart = rdd1.cartesian(rdd2)
+ result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect()
+
+ def test_transforming_pickle_file(self):
+ # Regression test for SPARK-2601
+ data = self.sc.parallelize([u"Hello", u"World!"])
+ tempFile = tempfile.NamedTemporaryFile(delete=True)
+ tempFile.close()
+ data.saveAsPickleFile(tempFile.name)
+ pickled_file = self.sc.pickleFile(tempFile.name)
+ pickled_file.map(lambda x: x).collect()
+
+ def test_cartesian_on_textfile(self):
+ # Regression test for
+ path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
+ a = self.sc.textFile(path)
+ result = a.cartesian(a).collect()
+ (x, y) = result[0]
+ self.assertEqual(u"Hello World!", x.strip())
+ self.assertEqual(u"Hello World!", y.strip())
+
+ def test_cartesian_chaining(self):
+ # Tests for SPARK-16589
+ rdd = self.sc.parallelize(range(10), 2)
+ self.assertSetEqual(
+ set(rdd.cartesian(rdd).cartesian(rdd).collect()),
+ set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)])
+ )
+
+ self.assertSetEqual(
+ set(rdd.cartesian(rdd.cartesian(rdd)).collect()),
+ set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)])
+ )
+
+ self.assertSetEqual(
+ set(rdd.cartesian(rdd.zip(rdd)).collect()),
+ set([(x, (y, y)) for x in range(10) for y in range(10)])
+ )
+
+ def test_zip_chaining(self):
+ # Tests for SPARK-21985
+ rdd = self.sc.parallelize('abc', 2)
+ self.assertSetEqual(
+ set(rdd.zip(rdd).zip(rdd).collect()),
+ set([((x, x), x) for x in 'abc'])
+ )
+ self.assertSetEqual(
+ set(rdd.zip(rdd.zip(rdd)).collect()),
+ set([(x, (x, x)) for x in 'abc'])
+ )
+
+ def test_deleting_input_files(self):
+ # Regression test for SPARK-1025
+ tempFile = tempfile.NamedTemporaryFile(delete=False)
+ tempFile.write(b"Hello World!")
+ tempFile.close()
+ data = self.sc.textFile(tempFile.name)
+ filtered_data = data.filter(lambda x: True)
+ self.assertEqual(1, filtered_data.count())
+ os.unlink(tempFile.name)
+ with QuietTest(self.sc):
+ self.assertRaises(Exception, lambda: filtered_data.count())
+
+ def test_sampling_default_seed(self):
+ # Test for SPARK-3995 (default seed setting)
+ data = self.sc.parallelize(xrange(1000), 1)
+ subset = data.takeSample(False, 10)
+ self.assertEqual(len(subset), 10)
+
+ def test_aggregate_mutable_zero_value(self):
+ # Test for SPARK-9021; uses aggregate and treeAggregate to build dict
+ # representing a counter of ints
+ # NOTE: dict is used instead of collections.Counter for Python 2.6
+ # compatibility
+ from collections import defaultdict
+
+ # Show that single or multiple partitions work
+ data1 = self.sc.range(10, numSlices=1)
+ data2 = self.sc.range(10, numSlices=2)
+
+ def seqOp(x, y):
+ x[y] += 1
+ return x
+
+ def comboOp(x, y):
+ for key, val in y.items():
+ x[key] += val
+ return x
+
+ counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp)
+ counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp)
+ counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
+ counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
+
+ ground_truth = defaultdict(int, dict((i, 1) for i in range(10)))
+ self.assertEqual(counts1, ground_truth)
+ self.assertEqual(counts2, ground_truth)
+ self.assertEqual(counts3, ground_truth)
+ self.assertEqual(counts4, ground_truth)
+
+ def test_aggregate_by_key_mutable_zero_value(self):
+ # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that
+ # contains lists of all values for each key in the original RDD
+
+ # list(range(...)) for Python 3.x compatibility (can't use * operator
+ # on a range object)
+ # list(zip(...)) for Python 3.x compatibility (want to parallelize a
+ # collection, not a zip object)
+ tuples = list(zip(list(range(10))*2, [1]*20))
+ # Show that single or multiple partitions work
+ data1 = self.sc.parallelize(tuples, 1)
+ data2 = self.sc.parallelize(tuples, 2)
+
+ def seqOp(x, y):
+ x.append(y)
+ return x
+
+ def comboOp(x, y):
+ x.extend(y)
+ return x
+
+ values1 = data1.aggregateByKey([], seqOp, comboOp).collect()
+ values2 = data2.aggregateByKey([], seqOp, comboOp).collect()
+ # Sort lists to ensure clean comparison with ground_truth
+ values1.sort()
+ values2.sort()
+
+ ground_truth = [(i, [1]*2) for i in range(10)]
+ self.assertEqual(values1, ground_truth)
+ self.assertEqual(values2, ground_truth)
+
+ def test_fold_mutable_zero_value(self):
+ # Test for SPARK-9021; uses fold to merge an RDD of dict counters into
+ # a single dict
+ # NOTE: dict is used instead of collections.Counter for Python 2.6
+ # compatibility
+ from collections import defaultdict
+
+ counts1 = defaultdict(int, dict((i, 1) for i in range(10)))
+ counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8)))
+ counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7)))
+ counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6)))
+ all_counts = [counts1, counts2, counts3, counts4]
+ # Show that single or multiple partitions work
+ data1 = self.sc.parallelize(all_counts, 1)
+ data2 = self.sc.parallelize(all_counts, 2)
+
+ def comboOp(x, y):
+ for key, val in y.items():
+ x[key] += val
+ return x
+
+ fold1 = data1.fold(defaultdict(int), comboOp)
+ fold2 = data2.fold(defaultdict(int), comboOp)
+
+ ground_truth = defaultdict(int)
+ for counts in all_counts:
+ for key, val in counts.items():
+ ground_truth[key] += val
+ self.assertEqual(fold1, ground_truth)
+ self.assertEqual(fold2, ground_truth)
+
+ def test_fold_by_key_mutable_zero_value(self):
+ # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains
+ # lists of all values for each key in the original RDD
+
+ tuples = [(i, range(i)) for i in range(10)]*2
+ # Show that single or multiple partitions work
+ data1 = self.sc.parallelize(tuples, 1)
+ data2 = self.sc.parallelize(tuples, 2)
+
+ def comboOp(x, y):
+ x.extend(y)
+ return x
+
+ values1 = data1.foldByKey([], comboOp).collect()
+ values2 = data2.foldByKey([], comboOp).collect()
+ # Sort lists to ensure clean comparison with ground_truth
+ values1.sort()
+ values2.sort()
+
+ # list(range(...)) for Python 3.x compatibility
+ ground_truth = [(i, list(range(i))*2) for i in range(10)]
+ self.assertEqual(values1, ground_truth)
+ self.assertEqual(values2, ground_truth)
+
+ def test_aggregate_by_key(self):
+ data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
+
+ def seqOp(x, y):
+ x.add(y)
+ return x
+
+ def combOp(x, y):
+ x |= y
+ return x
+
+ sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
+ self.assertEqual(3, len(sets))
+ self.assertEqual(set([1]), sets[1])
+ self.assertEqual(set([2]), sets[3])
+ self.assertEqual(set([1, 3]), sets[5])
+
+ def test_itemgetter(self):
+ rdd = self.sc.parallelize([range(10)])
+ from operator import itemgetter
+ self.assertEqual([1], rdd.map(itemgetter(1)).collect())
+ self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect())
+
+ def test_namedtuple_in_rdd(self):
+ from collections import namedtuple
+ Person = namedtuple("Person", "id firstName lastName")
+ jon = Person(1, "Jon", "Doe")
+ jane = Person(2, "Jane", "Doe")
+ theDoes = self.sc.parallelize([jon, jane])
+ self.assertEqual([jon, jane], theDoes.collect())
+
+ def test_large_broadcast(self):
+ N = 10000
+ data = [[float(i) for i in range(300)] for i in range(N)]
+ bdata = self.sc.broadcast(data) # 27MB
+ m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+ self.assertEqual(N, m)
+
+ def test_unpersist(self):
+ N = 1000
+ data = [[float(i) for i in range(300)] for i in range(N)]
+ bdata = self.sc.broadcast(data) # 3MB
+ bdata.unpersist()
+ m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+ self.assertEqual(N, m)
+ bdata.destroy()
+ try:
+ self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+ except Exception as e:
+ pass
+ else:
+ raise Exception("job should fail after destroy the broadcast")
+
+ def test_multiple_broadcasts(self):
+ N = 1 << 21
+ b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM
+ r = list(range(1 << 15))
+ random.shuffle(r)
+ s = str(r).encode()
+ checksum = hashlib.md5(s).hexdigest()
+ b2 = self.sc.broadcast(s)
+ r = list(set(self.sc.parallelize(range(10), 10).map(
+ lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+ self.assertEqual(1, len(r))
+ size, csum = r[0]
+ self.assertEqual(N, size)
+ self.assertEqual(checksum, csum)
+
+ random.shuffle(r)
+ s = str(r).encode()
+ checksum = hashlib.md5(s).hexdigest()
+ b2 = self.sc.broadcast(s)
+ r = list(set(self.sc.parallelize(range(10), 10).map(
+ lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+ self.assertEqual(1, len(r))
+ size, csum = r[0]
+ self.assertEqual(N, size)
+ self.assertEqual(checksum, csum)
+
+ def test_multithread_broadcast_pickle(self):
+ import threading
+
+ b1 = self.sc.broadcast(list(range(3)))
+ b2 = self.sc.broadcast(list(range(3)))
+
+ def f1():
+ return b1.value
+
+ def f2():
+ return b2.value
+
+ funcs_num_pickled = {f1: None, f2: None}
+
+ def do_pickle(f, sc):
+ command = (f, None, sc.serializer, sc.serializer)
+ ser = CloudPickleSerializer()
+ ser.dumps(command)
+
+ def process_vars(sc):
+ broadcast_vars = list(sc._pickled_broadcast_vars)
+ num_pickled = len(broadcast_vars)
+ sc._pickled_broadcast_vars.clear()
+ return num_pickled
+
+ def run(f, sc):
+ do_pickle(f, sc)
+ funcs_num_pickled[f] = process_vars(sc)
+
+ # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage
+ do_pickle(f1, self.sc)
+
+ # run all for f2, should only add/count/clear b2 from worker thread local storage
+ t = threading.Thread(target=run, args=(f2, self.sc))
+ t.start()
+ t.join()
+
+ # count number of vars pickled in main thread, only b1 should be counted and cleared
+ funcs_num_pickled[f1] = process_vars(self.sc)
+
+ self.assertEqual(funcs_num_pickled[f1], 1)
+ self.assertEqual(funcs_num_pickled[f2], 1)
+ self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
+
+ def test_large_closure(self):
+ N = 200000
+ data = [float(i) for i in xrange(N)]
+ rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
+ self.assertEqual(N, rdd.first())
+ # regression test for SPARK-6886
+ self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())
+
+ def test_zip_with_different_serializers(self):
+ a = self.sc.parallelize(range(5))
+ b = self.sc.parallelize(range(100, 105))
+ self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+ a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
+ b = b._reserialize(MarshalSerializer())
+ self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+ # regression test for SPARK-4841
+ path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
+ t = self.sc.textFile(path)
+ cnt = t.count()
+ self.assertEqual(cnt, t.zip(t).count())
+ rdd = t.map(str)
+ self.assertEqual(cnt, t.zip(rdd).count())
+ # regression test for bug in _reserializer()
+ self.assertEqual(cnt, t.zip(rdd).count())
+
+ def test_zip_with_different_object_sizes(self):
+ # regress test for SPARK-5973
+ a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i)
+ b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i)
+ self.assertEqual(10000, a.zip(b).count())
+
+ def test_zip_with_different_number_of_items(self):
+ a = self.sc.parallelize(range(5), 2)
+ # different number of partitions
+ b = self.sc.parallelize(range(100, 106), 3)
+ self.assertRaises(ValueError, lambda: a.zip(b))
+ with QuietTest(self.sc):
+ # different number of batched items in JVM
+ b = self.sc.parallelize(range(100, 104), 2)
+ self.assertRaises(Exception, lambda: a.zip(b).count())
+ # different number of items in one pair
+ b = self.sc.parallelize(range(100, 106), 2)
+ self.assertRaises(Exception, lambda: a.zip(b).count())
+ # same total number of items, but different distributions
+ a = self.sc.parallelize([2, 3], 2).flatMap(range)
+ b = self.sc.parallelize([3, 2], 2).flatMap(range)
+ self.assertEqual(a.count(), b.count())
+ self.assertRaises(Exception, lambda: a.zip(b).count())
+
+ def test_count_approx_distinct(self):
+ rdd = self.sc.parallelize(xrange(1000))
+ self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
+ self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
+ self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
+ self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050)
+
+ rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
+ self.assertTrue(18 < rdd.countApproxDistinct() < 22)
+ self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22)
+ self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22)
+ self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22)
+
+ self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001))
+
+ def test_histogram(self):
+ # empty
+ rdd = self.sc.parallelize([])
+ self.assertEqual([0], rdd.histogram([0, 10])[1])
+ self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
+ self.assertRaises(ValueError, lambda: rdd.histogram(1))
+
+ # out of range
+ rdd = self.sc.parallelize([10.01, -0.01])
+ self.assertEqual([0], rdd.histogram([0, 10])[1])
+ self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1])
+
+ # in range with one bucket
+ rdd = self.sc.parallelize(range(1, 5))
+ self.assertEqual([4], rdd.histogram([0, 10])[1])
+ self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1])
+
+ # in range with one bucket exact match
+ self.assertEqual([4], rdd.histogram([1, 4])[1])
+
+ # out of range with two buckets
+ rdd = self.sc.parallelize([10.01, -0.01])
+ self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1])
+
+ # out of range with two uneven buckets
+ rdd = self.sc.parallelize([10.01, -0.01])
+ self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
+
+ # in range with two buckets
+ rdd = self.sc.parallelize([1, 2, 3, 5, 6])
+ self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])
+
+ # in range with two bucket and None
+ rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')])
+ self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])
+
+ # in range with two uneven buckets
+ rdd = self.sc.parallelize([1, 2, 3, 5, 6])
+ self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1])
+
+ # mixed range with two uneven buckets
+ rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01])
+ self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1])
+
+ # mixed range with four uneven buckets
+ rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1])
+ self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
+
+ # mixed range with uneven buckets and NaN
+ rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0,
+ 199.0, 200.0, 200.1, None, float('nan')])
+ self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])
+
+ # out of range with infinite buckets
+ rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")])
+ self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1])
+
+ # invalid buckets
+ self.assertRaises(ValueError, lambda: rdd.histogram([]))
+ self.assertRaises(ValueError, lambda: rdd.histogram([1]))
+ self.assertRaises(ValueError, lambda: rdd.histogram(0))
+ self.assertRaises(TypeError, lambda: rdd.histogram({}))
+
+ # without buckets
+ rdd = self.sc.parallelize(range(1, 5))
+ self.assertEqual(([1, 4], [4]), rdd.histogram(1))
+
+ # without buckets single element
+ rdd = self.sc.parallelize([1])
+ self.assertEqual(([1, 1], [1]), rdd.histogram(1))
+
+ # without bucket no range
+ rdd = self.sc.parallelize([1] * 4)
+ self.assertEqual(([1, 1], [4]), rdd.histogram(1))
+
+ # without buckets basic two
+ rdd = self.sc.parallelize(range(1, 5))
+ self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2))
+
+ # without buckets with more requested than elements
+ rdd = self.sc.parallelize([1, 2])
+ buckets = [1 + 0.2 * i for i in range(6)]
+ hist = [1, 0, 0, 0, 1]
+ self.assertEqual((buckets, hist), rdd.histogram(5))
+
+ # invalid RDDs
+ rdd = self.sc.parallelize([1, float('inf')])
+ self.assertRaises(ValueError, lambda: rdd.histogram(2))
+ rdd = self.sc.parallelize([float('nan')])
+ self.assertRaises(ValueError, lambda: rdd.histogram(2))
+
+ # string
+ rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2)
+ self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1])
+ self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1))
+ self.assertRaises(TypeError, lambda: rdd.histogram(2))
+
+ def test_repartitionAndSortWithinPartitions_asc(self):
+ rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
+
+ repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True)
+ partitions = repartitioned.glom().collect()
+ self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)])
+ self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)])
+
+ def test_repartitionAndSortWithinPartitions_desc(self):
+ rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
+
+ repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False)
+ partitions = repartitioned.glom().collect()
+ self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)])
+ self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)])
+
+ def test_repartition_no_skewed(self):
+ num_partitions = 20
+ a = self.sc.parallelize(range(int(1000)), 2)
+ l = a.repartition(num_partitions).glom().map(len).collect()
+ zeros = len([x for x in l if x == 0])
+ self.assertTrue(zeros == 0)
+ l = a.coalesce(num_partitions, True).glom().map(len).collect()
+ zeros = len([x for x in l if x == 0])
+ self.assertTrue(zeros == 0)
+
+ def test_repartition_on_textfile(self):
+ path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
+ rdd = self.sc.textFile(path)
+ result = rdd.repartition(1).collect()
+ self.assertEqual(u"Hello World!", result[0])
+
+ def test_distinct(self):
+ rdd = self.sc.parallelize((1, 2, 3)*10, 10)
+ self.assertEqual(rdd.getNumPartitions(), 10)
+ self.assertEqual(rdd.distinct().count(), 3)
+ result = rdd.distinct(5)
+ self.assertEqual(result.getNumPartitions(), 5)
+ self.assertEqual(result.count(), 3)
+
+ def test_external_group_by_key(self):
+ self.sc._conf.set("spark.python.worker.memory", "1m")
+ N = 200001
+ kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
+ gkv = kv.groupByKey().cache()
+ self.assertEqual(3, gkv.count())
+ filtered = gkv.filter(lambda kv: kv[0] == 1)
+ self.assertEqual(1, filtered.count())
+ self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect())
+ self.assertEqual([(N // 3, N // 3)],
+ filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
+ result = filtered.collect()[0][1]
+ self.assertEqual(N // 3, len(result))
+ self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList))
+
+ def test_sort_on_empty_rdd(self):
+ self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
+
+ def test_sample(self):
+ rdd = self.sc.parallelize(range(0, 100), 4)
+ wo = rdd.sample(False, 0.1, 2).collect()
+ wo_dup = rdd.sample(False, 0.1, 2).collect()
+ self.assertSetEqual(set(wo), set(wo_dup))
+ wr = rdd.sample(True, 0.2, 5).collect()
+ wr_dup = rdd.sample(True, 0.2, 5).collect()
+ self.assertSetEqual(set(wr), set(wr_dup))
+ wo_s10 = rdd.sample(False, 0.3, 10).collect()
+ wo_s20 = rdd.sample(False, 0.3, 20).collect()
+ self.assertNotEqual(set(wo_s10), set(wo_s20))
+ wr_s11 = rdd.sample(True, 0.4, 11).collect()
+ wr_s21 = rdd.sample(True, 0.4, 21).collect()
+ self.assertNotEqual(set(wr_s11), set(wr_s21))
+
+ def test_null_in_rdd(self):
+ jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
+ rdd = RDD(jrdd, self.sc, UTF8Deserializer())
+ self.assertEqual([u"a", None, u"b"], rdd.collect())
+ rdd = RDD(jrdd, self.sc, NoOpSerializer())
+ self.assertEqual([b"a", None, b"b"], rdd.collect())
+
+ def test_multiple_python_java_RDD_conversions(self):
+ # Regression test for SPARK-5361
+ data = [
+ (u'1', {u'director': u'David Lean'}),
+ (u'2', {u'director': u'Andrew Dominik'})
+ ]
+ data_rdd = self.sc.parallelize(data)
+ data_java_rdd = data_rdd._to_java_object_rdd()
+ data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd)
+ converted_rdd = RDD(data_python_rdd, self.sc)
+ self.assertEqual(2, converted_rdd.count())
+
+ # a second conversion between Python and Java RDDs used to throw exceptions
+ data_java_rdd = converted_rdd._to_java_object_rdd()
+ data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd)
+ converted_rdd = RDD(data_python_rdd, self.sc)
+ self.assertEqual(2, converted_rdd.count())
+
+ # Regression test for SPARK-6294
+ def test_take_on_jrdd(self):
+ rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x))
+ rdd._jrdd.first()
+
+ def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
+ # Regression test for SPARK-5969
+ seq = [(i * 59 % 101, i) for i in range(101)] # unsorted sequence
+ rdd = self.sc.parallelize(seq)
+ for ascending in [True, False]:
+ sort = rdd.sortByKey(ascending=ascending, numPartitions=5)
+ self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending))
+ sizes = sort.glom().map(len).collect()
+ for size in sizes:
+ self.assertGreater(size, 0)
+
+ def test_pipe_functions(self):
+ data = ['1', '2', '3']
+ rdd = self.sc.parallelize(data)
+ with QuietTest(self.sc):
+ self.assertEqual([], rdd.pipe('cc').collect())
+ self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
+ result = rdd.pipe('cat').collect()
+ result.sort()
+ for x, y in zip(data, result):
+ self.assertEqual(x, y)
+ self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
+ self.assertEqual([], rdd.pipe('grep 4').collect())
+
+ def test_pipe_unicode(self):
+ # Regression test for SPARK-20947
+ data = [u'\u6d4b\u8bd5', '1']
+ rdd = self.sc.parallelize(data)
+ result = rdd.pipe('cat').collect()
+ self.assertEqual(data, result)
+
+ def test_stopiteration_in_user_code(self):
+
+ def stopit(*x):
+ raise StopIteration()
+
+ seq_rdd = self.sc.parallelize(range(10))
+ keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
+ msg = "Caught StopIteration thrown from user's code; failing the task"
+
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
+ self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+ self.assertRaisesRegexp(Py4JJavaError, msg,
+ seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
+
+ # these methods call the user function both in the driver and in the executor
+ # the exception raised is different according to where the StopIteration happens
+ # RuntimeError is raised if in the driver
+ # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
+ self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+ keyed_rdd.reduceByKeyLocally, stopit)
+ self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+ seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+ self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+ seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.tests.test_rdd import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/tests/test_readwrite.py b/python/pyspark/tests/test_readwrite.py
new file mode 100644
index 0000000000000..e45f5b371f461
--- /dev/null
+++ b/python/pyspark/tests/test_readwrite.py
@@ -0,0 +1,499 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import shutil
+import sys
+import tempfile
+import unittest
+from array import array
+
+from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME
+
+
+class InputFormatTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name)
+
+ @unittest.skipIf(sys.version >= "3", "serialize array of byte")
+ def test_sequencefiles(self):
+ basepath = self.tempdir.name
+ ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text").collect())
+ ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+ self.assertEqual(ints, ei)
+
+ doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/",
+ "org.apache.hadoop.io.DoubleWritable",
+ "org.apache.hadoop.io.Text").collect())
+ ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')]
+ self.assertEqual(doubles, ed)
+
+ bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.BytesWritable").collect())
+ ebs = [(1, bytearray('aa', 'utf-8')),
+ (1, bytearray('aa', 'utf-8')),
+ (2, bytearray('aa', 'utf-8')),
+ (2, bytearray('bb', 'utf-8')),
+ (2, bytearray('bb', 'utf-8')),
+ (3, bytearray('cc', 'utf-8'))]
+ self.assertEqual(bytes, ebs)
+
+ text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/",
+ "org.apache.hadoop.io.Text",
+ "org.apache.hadoop.io.Text").collect())
+ et = [(u'1', u'aa'),
+ (u'1', u'aa'),
+ (u'2', u'aa'),
+ (u'2', u'bb'),
+ (u'2', u'bb'),
+ (u'3', u'cc')]
+ self.assertEqual(text, et)
+
+ bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.BooleanWritable").collect())
+ eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)]
+ self.assertEqual(bools, eb)
+
+ nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.BooleanWritable").collect())
+ en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
+ self.assertEqual(nulls, en)
+
+ maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.MapWritable").collect()
+ em = [(1, {}),
+ (1, {3.0: u'bb'}),
+ (2, {1.0: u'aa'}),
+ (2, {1.0: u'cc'}),
+ (3, {2.0: u'dd'})]
+ for v in maps:
+ self.assertTrue(v in em)
+
+ # arrays get pickled to tuples by default
+ tuples = sorted(self.sc.sequenceFile(
+ basepath + "/sftestdata/sfarray/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.spark.api.python.DoubleArrayWritable").collect())
+ et = [(1, ()),
+ (2, (3.0, 4.0, 5.0)),
+ (3, (4.0, 5.0, 6.0))]
+ self.assertEqual(tuples, et)
+
+ # with custom converters, primitive arrays can stay as arrays
+ arrays = sorted(self.sc.sequenceFile(
+ basepath + "/sftestdata/sfarray/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.spark.api.python.DoubleArrayWritable",
+ valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect())
+ ea = [(1, array('d')),
+ (2, array('d', [3.0, 4.0, 5.0])),
+ (3, array('d', [4.0, 5.0, 6.0]))]
+ self.assertEqual(arrays, ea)
+
+ clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
+ "org.apache.hadoop.io.Text",
+ "org.apache.spark.api.python.TestWritable").collect())
+ cname = u'org.apache.spark.api.python.TestWritable'
+ ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}),
+ (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}),
+ (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}),
+ (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}),
+ (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})]
+ self.assertEqual(clazz, ec)
+
+ unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
+ "org.apache.hadoop.io.Text",
+ "org.apache.spark.api.python.TestWritable",
+ ).collect())
+ self.assertEqual(unbatched_clazz, ec)
+
+ def test_oldhadoop(self):
+ basepath = self.tempdir.name
+ ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapred.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text").collect())
+ ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+ self.assertEqual(ints, ei)
+
+ hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
+ oldconf = {"mapreduce.input.fileinputformat.inputdir": hellopath}
+ hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat",
+ "org.apache.hadoop.io.LongWritable",
+ "org.apache.hadoop.io.Text",
+ conf=oldconf).collect()
+ result = [(0, u'Hello World!')]
+ self.assertEqual(hello, result)
+
+ def test_newhadoop(self):
+ basepath = self.tempdir.name
+ ints = sorted(self.sc.newAPIHadoopFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text").collect())
+ ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+ self.assertEqual(ints, ei)
+
+ hellopath = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
+ newconf = {"mapreduce.input.fileinputformat.inputdir": hellopath}
+ hello = self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat",
+ "org.apache.hadoop.io.LongWritable",
+ "org.apache.hadoop.io.Text",
+ conf=newconf).collect()
+ result = [(0, u'Hello World!')]
+ self.assertEqual(hello, result)
+
+ def test_newolderror(self):
+ basepath = self.tempdir.name
+ self.assertRaises(Exception, lambda: self.sc.hadoopFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text"))
+
+ self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapred.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text"))
+
+ def test_bad_inputs(self):
+ basepath = self.tempdir.name
+ self.assertRaises(Exception, lambda: self.sc.sequenceFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.io.NotValidWritable",
+ "org.apache.hadoop.io.Text"))
+ self.assertRaises(Exception, lambda: self.sc.hadoopFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapred.NotValidInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text"))
+ self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text"))
+
+ def test_converters(self):
+ # use of custom converters
+ basepath = self.tempdir.name
+ maps = sorted(self.sc.sequenceFile(
+ basepath + "/sftestdata/sfmap/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.MapWritable",
+ keyConverter="org.apache.spark.api.python.TestInputKeyConverter",
+ valueConverter="org.apache.spark.api.python.TestInputValueConverter").collect())
+ em = [(u'\x01', []),
+ (u'\x01', [3.0]),
+ (u'\x02', [1.0]),
+ (u'\x02', [1.0]),
+ (u'\x03', [2.0])]
+ self.assertEqual(maps, em)
+
+ def test_binary_files(self):
+ path = os.path.join(self.tempdir.name, "binaryfiles")
+ os.mkdir(path)
+ data = b"short binary data"
+ with open(os.path.join(path, "part-0000"), 'wb') as f:
+ f.write(data)
+ [(p, d)] = self.sc.binaryFiles(path).collect()
+ self.assertTrue(p.endswith("part-0000"))
+ self.assertEqual(d, data)
+
+ def test_binary_records(self):
+ path = os.path.join(self.tempdir.name, "binaryrecords")
+ os.mkdir(path)
+ with open(os.path.join(path, "part-0000"), 'w') as f:
+ for i in range(100):
+ f.write('%04d' % i)
+ result = self.sc.binaryRecords(path, 4).map(int).collect()
+ self.assertEqual(list(range(100)), result)
+
+
+class OutputFormatTests(ReusedPySparkTestCase):
+
+ def setUp(self):
+ self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(self.tempdir.name)
+
+ def tearDown(self):
+ shutil.rmtree(self.tempdir.name, ignore_errors=True)
+
+ @unittest.skipIf(sys.version >= "3", "serialize array of byte")
+ def test_sequencefiles(self):
+ basepath = self.tempdir.name
+ ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+ self.sc.parallelize(ei).saveAsSequenceFile(basepath + "/sfint/")
+ ints = sorted(self.sc.sequenceFile(basepath + "/sfint/").collect())
+ self.assertEqual(ints, ei)
+
+ ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')]
+ self.sc.parallelize(ed).saveAsSequenceFile(basepath + "/sfdouble/")
+ doubles = sorted(self.sc.sequenceFile(basepath + "/sfdouble/").collect())
+ self.assertEqual(doubles, ed)
+
+ ebs = [(1, bytearray(b'\x00\x07spam\x08')), (2, bytearray(b'\x00\x07spam\x08'))]
+ self.sc.parallelize(ebs).saveAsSequenceFile(basepath + "/sfbytes/")
+ bytes = sorted(self.sc.sequenceFile(basepath + "/sfbytes/").collect())
+ self.assertEqual(bytes, ebs)
+
+ et = [(u'1', u'aa'),
+ (u'2', u'bb'),
+ (u'3', u'cc')]
+ self.sc.parallelize(et).saveAsSequenceFile(basepath + "/sftext/")
+ text = sorted(self.sc.sequenceFile(basepath + "/sftext/").collect())
+ self.assertEqual(text, et)
+
+ eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)]
+ self.sc.parallelize(eb).saveAsSequenceFile(basepath + "/sfbool/")
+ bools = sorted(self.sc.sequenceFile(basepath + "/sfbool/").collect())
+ self.assertEqual(bools, eb)
+
+ en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
+ self.sc.parallelize(en).saveAsSequenceFile(basepath + "/sfnull/")
+ nulls = sorted(self.sc.sequenceFile(basepath + "/sfnull/").collect())
+ self.assertEqual(nulls, en)
+
+ em = [(1, {}),
+ (1, {3.0: u'bb'}),
+ (2, {1.0: u'aa'}),
+ (2, {1.0: u'cc'}),
+ (3, {2.0: u'dd'})]
+ self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/")
+ maps = self.sc.sequenceFile(basepath + "/sfmap/").collect()
+ for v in maps:
+ self.assertTrue(v in em)
+
+ def test_oldhadoop(self):
+ basepath = self.tempdir.name
+ dict_data = [(1, {}),
+ (1, {"row1": 1.0}),
+ (2, {"row2": 2.0})]
+ self.sc.parallelize(dict_data).saveAsHadoopFile(
+ basepath + "/oldhadoop/",
+ "org.apache.hadoop.mapred.SequenceFileOutputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.MapWritable")
+ result = self.sc.hadoopFile(
+ basepath + "/oldhadoop/",
+ "org.apache.hadoop.mapred.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.MapWritable").collect()
+ for v in result:
+ self.assertTrue(v in dict_data)
+
+ conf = {
+ "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat",
+ "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
+ "mapreduce.job.output.value.class": "org.apache.hadoop.io.MapWritable",
+ "mapreduce.output.fileoutputformat.outputdir": basepath + "/olddataset/"
+ }
+ self.sc.parallelize(dict_data).saveAsHadoopDataset(conf)
+ input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/olddataset/"}
+ result = self.sc.hadoopRDD(
+ "org.apache.hadoop.mapred.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.MapWritable",
+ conf=input_conf).collect()
+ for v in result:
+ self.assertTrue(v in dict_data)
+
+ def test_newhadoop(self):
+ basepath = self.tempdir.name
+ data = [(1, ""),
+ (1, "a"),
+ (2, "bcdf")]
+ self.sc.parallelize(data).saveAsNewAPIHadoopFile(
+ basepath + "/newhadoop/",
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text")
+ result = sorted(self.sc.newAPIHadoopFile(
+ basepath + "/newhadoop/",
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text").collect())
+ self.assertEqual(result, data)
+
+ conf = {
+ "mapreduce.job.outputformat.class":
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+ "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
+ "mapreduce.job.output.value.class": "org.apache.hadoop.io.Text",
+ "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/"
+ }
+ self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf)
+ input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"}
+ new_dataset = sorted(self.sc.newAPIHadoopRDD(
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text",
+ conf=input_conf).collect())
+ self.assertEqual(new_dataset, data)
+
+ @unittest.skipIf(sys.version >= "3", "serialize of array")
+ def test_newhadoop_with_array(self):
+ basepath = self.tempdir.name
+ # use custom ArrayWritable types and converters to handle arrays
+ array_data = [(1, array('d')),
+ (1, array('d', [1.0, 2.0, 3.0])),
+ (2, array('d', [3.0, 4.0, 5.0]))]
+ self.sc.parallelize(array_data).saveAsNewAPIHadoopFile(
+ basepath + "/newhadoop/",
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.spark.api.python.DoubleArrayWritable",
+ valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter")
+ result = sorted(self.sc.newAPIHadoopFile(
+ basepath + "/newhadoop/",
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.spark.api.python.DoubleArrayWritable",
+ valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect())
+ self.assertEqual(result, array_data)
+
+ conf = {
+ "mapreduce.job.outputformat.class":
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+ "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
+ "mapreduce.job.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable",
+ "mapreduce.output.fileoutputformat.outputdir": basepath + "/newdataset/"
+ }
+ self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset(
+ conf,
+ valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter")
+ input_conf = {"mapreduce.input.fileinputformat.inputdir": basepath + "/newdataset/"}
+ new_dataset = sorted(self.sc.newAPIHadoopRDD(
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.spark.api.python.DoubleArrayWritable",
+ valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter",
+ conf=input_conf).collect())
+ self.assertEqual(new_dataset, array_data)
+
+ def test_newolderror(self):
+ basepath = self.tempdir.name
+ rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x))
+ self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile(
+ basepath + "/newolderror/saveAsHadoopFile/",
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat"))
+ self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile(
+ basepath + "/newolderror/saveAsNewAPIHadoopFile/",
+ "org.apache.hadoop.mapred.SequenceFileOutputFormat"))
+
+ def test_bad_inputs(self):
+ basepath = self.tempdir.name
+ rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x))
+ self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile(
+ basepath + "/badinputs/saveAsHadoopFile/",
+ "org.apache.hadoop.mapred.NotValidOutputFormat"))
+ self.assertRaises(Exception, lambda: rdd.saveAsNewAPIHadoopFile(
+ basepath + "/badinputs/saveAsNewAPIHadoopFile/",
+ "org.apache.hadoop.mapreduce.lib.output.NotValidOutputFormat"))
+
+ def test_converters(self):
+ # use of custom converters
+ basepath = self.tempdir.name
+ data = [(1, {3.0: u'bb'}),
+ (2, {1.0: u'aa'}),
+ (3, {2.0: u'dd'})]
+ self.sc.parallelize(data).saveAsNewAPIHadoopFile(
+ basepath + "/converters/",
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+ keyConverter="org.apache.spark.api.python.TestOutputKeyConverter",
+ valueConverter="org.apache.spark.api.python.TestOutputValueConverter")
+ converted = sorted(self.sc.sequenceFile(basepath + "/converters/").collect())
+ expected = [(u'1', 3.0),
+ (u'2', 1.0),
+ (u'3', 2.0)]
+ self.assertEqual(converted, expected)
+
+ def test_reserialization(self):
+ basepath = self.tempdir.name
+ x = range(1, 5)
+ y = range(1001, 1005)
+ data = list(zip(x, y))
+ rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y))
+ rdd.saveAsSequenceFile(basepath + "/reserialize/sequence")
+ result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect())
+ self.assertEqual(result1, data)
+
+ rdd.saveAsHadoopFile(
+ basepath + "/reserialize/hadoop",
+ "org.apache.hadoop.mapred.SequenceFileOutputFormat")
+ result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect())
+ self.assertEqual(result2, data)
+
+ rdd.saveAsNewAPIHadoopFile(
+ basepath + "/reserialize/newhadoop",
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")
+ result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect())
+ self.assertEqual(result3, data)
+
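+ # saveAsHadoopDataset uses the old "mapred" API, which reads the output format from "mapred.output.format.class"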
+ conf4 = {
+ "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat",
+ "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
+ "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable",
+ "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/dataset"}
+ rdd.saveAsHadoopDataset(conf4)
+ result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect())
+ self.assertEqual(result4, data)
+
+ conf5 = {"mapreduce.job.outputformat.class":
+ "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat",
+ "mapreduce.job.output.key.class": "org.apache.hadoop.io.IntWritable",
+ "mapreduce.job.output.value.class": "org.apache.hadoop.io.IntWritable",
+ "mapreduce.output.fileoutputformat.outputdir": basepath + "/reserialize/newdataset"
+ }
+ rdd.saveAsNewAPIHadoopDataset(conf5)
+ result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect())
+ self.assertEqual(result5, data)
+
+ def test_malformed_RDD(self):
+ basepath = self.tempdir.name
+ # non-batch-serialized RDD[[(K, V)]] should be rejected
+ data = [[(1, "a")], [(2, "aa")], [(3, "aaa")]]
+ rdd = self.sc.parallelize(data, len(data))
+ self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile(
+ basepath + "/malformed/sequence"))
+
+
+if __name__ == "__main__":
+ from pyspark.tests.test_readwrite import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py
new file mode 100644
index 0000000000000..bce94062c8af7
--- /dev/null
+++ b/python/pyspark/tests/test_serializers.py
@@ -0,0 +1,237 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import math
+import sys
+import unittest
+
+from pyspark import serializers
+from pyspark.serializers import *
+from pyspark.serializers import CloudPickleSerializer, CompressedSerializer, \
+ AutoBatchedSerializer, BatchedSerializer, AutoSerializer, NoOpSerializer, PairDeserializer, \
+ FlattenedValuesSerializer, CartesianDeserializer
+from pyspark.testing.utils import PySparkTestCase, read_int, write_int, ByteArrayOutput, \
+ have_numpy, have_scipy
+
+
+class SerializationTestCase(unittest.TestCase):
+
+ def test_namedtuple(self):
+ from collections import namedtuple
+ from pickle import dumps, loads
+ P = namedtuple("P", "x y")
+ p1 = P(1, 3)
+ p2 = loads(dumps(p1, 2))
+ self.assertEqual(p1, p2)
+
+ from pyspark.cloudpickle import dumps
+ P2 = loads(dumps(P))
+ p3 = P2(1, 3)
+ self.assertEqual(p1, p3)
+
+ def test_itemgetter(self):
+ from operator import itemgetter
+ ser = CloudPickleSerializer()
+ d = range(10)
+ getter = itemgetter(1)
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ getter = itemgetter(0, 3)
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ def test_function_module_name(self):
+ ser = CloudPickleSerializer()
+ func = lambda x: x
+ func2 = ser.loads(ser.dumps(func))
+ self.assertEqual(func.__module__, func2.__module__)
+
+ def test_attrgetter(self):
+ from operator import attrgetter
+ ser = CloudPickleSerializer()
+
+ class C(object):
+ def __getattr__(self, item):
+ return item
+ d = C()
+ getter = attrgetter("a")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+ getter = attrgetter("a", "b")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ d.e = C()
+ getter = attrgetter("e.a")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+ getter = attrgetter("e.a", "e.b")
+ getter2 = ser.loads(ser.dumps(getter))
+ self.assertEqual(getter(d), getter2(d))
+
+ # Regression test for SPARK-3415
+ def test_pickling_file_handles(self):
+ # to be corrected with SPARK-11160
+ try:
+ import xmlrunner
+ except ImportError:
+ ser = CloudPickleSerializer()
+ out1 = sys.stderr
+ out2 = ser.loads(ser.dumps(out1))
+ self.assertEqual(out1, out2)
+
+ def test_func_globals(self):
+
+ class Unpicklable(object):
+ def __reduce__(self):
+ raise Exception("not picklable")
+
+ global exit
+ exit = Unpicklable()
+
+ ser = CloudPickleSerializer()
+ self.assertRaises(Exception, lambda: ser.dumps(exit))
+
+ def foo():
+ sys.exit(0)
+
+ self.assertTrue("exit" in foo.__code__.co_names)
+ ser.dumps(foo)
+
+ def test_compressed_serializer(self):
+ ser = CompressedSerializer(PickleSerializer())
+ try:
+ from StringIO import StringIO
+ except ImportError:
+ from io import BytesIO as StringIO
+ io = StringIO()
+ ser.dump_stream(["abc", u"123", range(5)], io)
+ io.seek(0)
+ self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
+ ser.dump_stream(range(1000), io)
+ io.seek(0)
+ self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io)))
+ io.close()
+
+ def test_hash_serializer(self):
+ hash(NoOpSerializer())
+ hash(UTF8Deserializer())
+ hash(PickleSerializer())
+ hash(MarshalSerializer())
+ hash(AutoSerializer())
+ hash(BatchedSerializer(PickleSerializer()))
+ hash(AutoBatchedSerializer(MarshalSerializer()))
+ hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer()))
+ hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer()))
+ hash(CompressedSerializer(PickleSerializer()))
+ hash(FlattenedValuesSerializer(PickleSerializer()))
+
+
+@unittest.skipIf(not have_scipy, "SciPy not installed")
+class SciPyTests(PySparkTestCase):
+
+ """General PySpark tests that depend on scipy """
+
+ def test_serialize(self):
+ from scipy.special import gammaln
+
+ x = range(1, 5)
+ expected = list(map(gammaln, x))
+ observed = self.sc.parallelize(x).map(gammaln).collect()
+ self.assertEqual(expected, observed)
+
+
+@unittest.skipIf(not have_numpy, "NumPy not installed")
+class NumPyTests(PySparkTestCase):
+
+ """General PySpark tests that depend on numpy """
+
+ def test_statcounter_array(self):
+ import numpy as np
+
+ x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
+ s = x.stats()
+ self.assertSequenceEqual([2.0, 2.0], s.mean().tolist())
+ self.assertSequenceEqual([1.0, 1.0], s.min().tolist())
+ self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
+ self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())
+
+ stats_dict = s.asDict()
+ self.assertEqual(3, stats_dict['count'])
+ self.assertSequenceEqual([2.0, 2.0], stats_dict['mean'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_dict['min'].tolist())
+ self.assertSequenceEqual([3.0, 3.0], stats_dict['max'].tolist())
+ self.assertSequenceEqual([6.0, 6.0], stats_dict['sum'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_dict['stdev'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_dict['variance'].tolist())
+
+ stats_sample_dict = s.asDict(sample=True)
+ self.assertEqual(3, stats_dict['count'])
+ self.assertSequenceEqual([2.0, 2.0], stats_sample_dict['mean'].tolist())
+ self.assertSequenceEqual([1.0, 1.0], stats_sample_dict['min'].tolist())
+ self.assertSequenceEqual([3.0, 3.0], stats_sample_dict['max'].tolist())
+ self.assertSequenceEqual([6.0, 6.0], stats_sample_dict['sum'].tolist())
+ self.assertSequenceEqual(
+ [0.816496580927726, 0.816496580927726], stats_sample_dict['stdev'].tolist())
+ self.assertSequenceEqual(
+ [0.6666666666666666, 0.6666666666666666], stats_sample_dict['variance'].tolist())
+
+
+class SerializersTest(unittest.TestCase):
+
+ def test_chunked_stream(self):
+ original_bytes = bytearray(range(100))
+ for data_length in [1, 10, 100]:
+ for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
+ dest = ByteArrayOutput()
+ stream_out = serializers.ChunkedStream(dest, buffer_length)
+ stream_out.write(original_bytes[:data_length])
+ stream_out.close()
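+ # close() flushes any remaining buffered bytes as a final chunk and writes the -1 end-of-stream marker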
+ num_chunks = int(math.ceil(float(data_length) / buffer_length))
+ # length for each chunk, and a final -1 at the very end
+ exp_size = (num_chunks + 1) * 4 + data_length
+ self.assertEqual(len(dest.buffer), exp_size)
+ dest_pos = 0
+ data_pos = 0
+ for chunk_idx in range(num_chunks):
+ chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
+ if chunk_idx == num_chunks - 1:
+ exp_length = data_length % buffer_length
+ if exp_length == 0:
+ exp_length = buffer_length
+ else:
+ exp_length = buffer_length
+ self.assertEqual(chunk_length, exp_length)
+ dest_pos += 4
+ dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
+ orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
+ self.assertEqual(dest_chunk, orig_chunk)
+ dest_pos += chunk_length
+ data_pos += chunk_length
+ # ends with a -1
+ self.assertEqual(dest.buffer[-4:], write_int(-1))
+
+
+if __name__ == "__main__":
+ from pyspark.tests.test_serializers import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py
new file mode 100644
index 0000000000000..0489426061b75
--- /dev/null
+++ b/python/pyspark/tests/test_shuffle.py
@@ -0,0 +1,181 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import random
+import sys
+import unittest
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext
+from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
+
+if sys.version_info[0] >= 3:
+ xrange = range
+
+
+class MergerTests(unittest.TestCase):
+
+ def setUp(self):
+ self.N = 1 << 12
+ self.l = [i for i in xrange(self.N)]
+ self.data = list(zip(self.l, self.l))
+ self.agg = Aggregator(lambda x: [x],
+ lambda x, y: x.append(y) or x,
+ lambda x, y: x.extend(y) or x)
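+ # combiners are plain lists: the first lambda wraps a value, the second appends a value, the third concatenates two lists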
+
+ def test_small_dataset(self):
+ m = ExternalMerger(self.agg, 1000)
+ m.mergeValues(self.data)
+ self.assertEqual(m.spills, 0)
+ self.assertEqual(sum(sum(v) for k, v in m.items()),
+ sum(xrange(self.N)))
+
+ m = ExternalMerger(self.agg, 1000)
+ m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data))
+ self.assertEqual(m.spills, 0)
+ self.assertEqual(sum(sum(v) for k, v in m.items()),
+ sum(xrange(self.N)))
+
+ def test_medium_dataset(self):
+ m = ExternalMerger(self.agg, 20)
+ m.mergeValues(self.data)
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.items()),
+ sum(xrange(self.N)))
+
+ m = ExternalMerger(self.agg, 10)
+ m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3))
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.items()),
+ sum(xrange(self.N)) * 3)
+
+ def test_huge_dataset(self):
+ m = ExternalMerger(self.agg, 5, partitions=3)
+ m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10))
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(len(v) for k, v in m.items()),
+ self.N * 10)
+ m._cleanup()
+
+ def test_group_by_key(self):
+
+ def gen_data(N, step):
+ for i in range(1, N + 1, step):
+ for j in range(i):
+ yield (i, [j])
+
+ def gen_gs(N, step=1):
+ return shuffle.GroupByKey(gen_data(N, step))
+
+ self.assertEqual(1, len(list(gen_gs(1))))
+ self.assertEqual(2, len(list(gen_gs(2))))
+ self.assertEqual(100, len(list(gen_gs(100))))
+ self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)])
+ self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100)))
+
+ for k, vs in gen_gs(50002, 10000):
+ self.assertEqual(k, len(vs))
+ self.assertEqual(list(range(k)), list(vs))
+
+ ser = PickleSerializer()
+ l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
+ for k, vs in l:
+ self.assertEqual(k, len(vs))
+ self.assertEqual(list(range(k)), list(vs))
+
+ def test_stopiteration_is_raised(self):
+
+ def stopit(*args, **kwargs):
+ raise StopIteration()
+
+ def legit_create_combiner(x):
+ return [x]
+
+ def legit_merge_value(x, y):
+ return x.append(y) or x
+
+ def legit_merge_combiners(x, y):
+ return x.extend(y) or x
+
+ data = [(x % 2, x) for x in range(100)]
+
+ # wrong create combiner
+ m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
+ with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+ m.mergeValues(data)
+
+ # wrong merge value
+ m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
+ with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+ m.mergeValues(data)
+
+ # wrong merge combiners
+ m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
+ with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+ m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
+
+
+class SorterTests(unittest.TestCase):
+ def test_in_memory_sort(self):
+ l = list(range(1024))
+ random.shuffle(l)
+ sorter = ExternalSorter(1024)
+ self.assertEqual(sorted(l), list(sorter.sorted(l)))
+ self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
+ self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
+ self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
+ list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
+
+ def test_external_sort(self):
+ class CustomizedSorter(ExternalSorter):
+ def _next_limit(self):
+ return self.memory_limit
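+ # keeping the memory limit at its initial value forces the sorter to spill to disk (checked via DiskBytesSpilled below)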
+ l = list(range(1024))
+ random.shuffle(l)
+ sorter = CustomizedSorter(1)
+ self.assertEqual(sorted(l), list(sorter.sorted(l)))
+ self.assertGreater(shuffle.DiskBytesSpilled, 0)
+ last = shuffle.DiskBytesSpilled
+ self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
+ self.assertGreater(shuffle.DiskBytesSpilled, last)
+ last = shuffle.DiskBytesSpilled
+ self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
+ self.assertGreater(shuffle.DiskBytesSpilled, last)
+ last = shuffle.DiskBytesSpilled
+ self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
+ list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
+ self.assertGreater(shuffle.DiskBytesSpilled, last)
+
+ def test_external_sort_in_rdd(self):
+ conf = SparkConf().set("spark.python.worker.memory", "1m")
+ sc = SparkContext(conf=conf)
+ l = list(range(10240))
+ random.shuffle(l)
+ rdd = sc.parallelize(l, 4)
+ self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
+ sc.stop()
+
+
+if __name__ == "__main__":
+ from pyspark.tests.test_shuffle import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
new file mode 100644
index 0000000000000..b3a967440a9b2
--- /dev/null
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -0,0 +1,161 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import random
+import sys
+import time
+
+from pyspark import SparkContext, TaskContext, BarrierTaskContext
+from pyspark.testing.utils import PySparkTestCase
+
+
+class TaskContextTests(PySparkTestCase):
+
+ def setUp(self):
+ self._old_sys_path = list(sys.path)
+ class_name = self.__class__.__name__
+ # Allow retries even though they are normally disabled in local mode
+ self.sc = SparkContext('local[4, 2]', class_name)
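+ # 'local[4, 2]' runs 4 worker threads with maxFailures=2, so a failed task can be retried once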
+
+ def test_stage_id(self):
+ """Test the stage ids are available and incrementing as expected."""
+ rdd = self.sc.parallelize(range(10))
+ stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+ stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
+ # Test using the constructor directly rather than the get()
+ stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
+ self.assertEqual(stage1 + 1, stage2)
+ self.assertEqual(stage1 + 2, stage3)
+ self.assertEqual(stage2 + 1, stage3)
+
+ def test_partition_id(self):
+ """Test the partition id."""
+ rdd1 = self.sc.parallelize(range(10), 1)
+ rdd2 = self.sc.parallelize(range(10), 2)
+ pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
+ pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
+ self.assertEqual(0, pids1[0])
+ self.assertEqual(0, pids1[9])
+ self.assertEqual(0, pids2[0])
+ self.assertEqual(1, pids2[9])
+
+ def test_attempt_number(self):
+ """Verify the attempt numbers are correctly reported."""
+ rdd = self.sc.parallelize(range(10))
+ # Verify a simple job with no failures
+ attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
+ self.assertTrue(all(attempt == 0 for attempt in attempt_numbers))
+
+ def fail_on_first(x):
+ """Fail on the first attempt so we get a positive attempt number"""
+ tc = TaskContext.get()
+ attempt_number = tc.attemptNumber()
+ partition_id = tc.partitionId()
+ attempt_id = tc.taskAttemptId()
+ if attempt_number == 0 and partition_id == 0:
+ raise Exception("Failing on first attempt")
+ else:
+ return [x, partition_id, attempt_number, attempt_id]
+ result = rdd.map(fail_on_first).collect()
+ # The first partition should be re-submitted (attempt 1) while the other partitions stay at attempt 0
+ self.assertEqual([0, 0, 1], result[0][0:3])
+ self.assertEqual([9, 3, 0], result[9][0:3])
+ first_partition = [x for x in result if x[1] == 0]
+ self.assertTrue(all(x[2] == 1 for x in first_partition))
+ other_partitions = [x for x in result if x[1] != 0]
+ self.assertTrue(all(x[2] == 0 for x in other_partitions))
+ # The task attempt id should be different
+ self.assertTrue(result[0][3] != result[9][3])
+
+ def test_tc_on_driver(self):
+ """Verify that getting the TaskContext on the driver returns None."""
+ tc = TaskContext.get()
+ self.assertTrue(tc is None)
+
+ def test_get_local_property(self):
+ """Verify that local properties set on the driver are available in TaskContext."""
+ key = "testkey"
+ value = "testvalue"
+ self.sc.setLocalProperty(key, value)
+ try:
+ rdd = self.sc.parallelize(range(1), 1)
+ prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0]
+ self.assertEqual(prop1, value)
+ prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
+ self.assertTrue(prop2 is None)
+ finally:
+ self.sc.setLocalProperty(key, None)
+
+ def test_barrier(self):
+ """
+ Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
+ within a stage.
+ """
+ rdd = self.sc.parallelize(range(10), 4)
+
+ def f(iterator):
+ yield sum(iterator)
+
+ def context_barrier(x):
+ tc = BarrierTaskContext.get()
+ time.sleep(random.randint(1, 10))
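+ # tasks sleep for different lengths of time; barrier() should release them all at (nearly) the same moment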
+ tc.barrier()
+ return time.time()
+
+ times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
+ self.assertTrue(max(times) - min(times) < 1)
+
+ def test_barrier_with_python_worker_reuse(self):
+ """
+ Verify that BarrierTaskContext.barrier() works with a reused Python worker.
+ """
+ self.sc._conf.set("spark.python.worker.reuse", "true")
+ rdd = self.sc.parallelize(range(4), 4)
+ # start a normal job first to start all workers
+ result = rdd.map(lambda x: x ** 2).collect()
+ self.assertEqual([0, 1, 4, 9], result)
+ # make sure `spark.python.worker.reuse=true`
+ self.assertEqual(self.sc._conf.get("spark.python.worker.reuse"), "true")
+
+ # worker will be reused in this barrier job
+ self.test_barrier()
+
+ def test_barrier_infos(self):
+ """
+ Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
+ barrier stage.
+ """
+ rdd = self.sc.parallelize(range(10), 4)
+
+ def f(iterator):
+ yield sum(iterator)
+
+ taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get()
+ .getTaskInfos()).collect()
+ self.assertTrue(len(taskInfos) == 4)
+ self.assertTrue(len(taskInfos[0]) == 4)
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.tests.test_taskcontext import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py
new file mode 100644
index 0000000000000..11cda8fd2f5cd
--- /dev/null
+++ b/python/pyspark/tests/test_util.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark import keyword_only
+from pyspark.testing.utils import PySparkTestCase
+
+
+class KeywordOnlyTests(unittest.TestCase):
+ class Wrapped(object):
+ @keyword_only
+ def set(self, x=None, y=None):
+ if "x" in self._input_kwargs:
+ self._x = self._input_kwargs["x"]
+ if "y" in self._input_kwargs:
+ self._y = self._input_kwargs["y"]
+ return x, y
+
+ def test_keywords(self):
+ w = self.Wrapped()
+ x, y = w.set(y=1)
+ self.assertEqual(y, 1)
+ self.assertEqual(y, w._y)
+ self.assertIsNone(x)
+ self.assertFalse(hasattr(w, "_x"))
+
+ def test_non_keywords(self):
+ w = self.Wrapped()
+ self.assertRaises(TypeError, lambda: w.set(0, y=1))
+
+ def test_kwarg_ownership(self):
+ # test _input_kwargs is owned by each class instance and not a shared static variable
+ class Setter(object):
+ @keyword_only
+ def set(self, x=None, other=None, other_x=None):
+ if "other" in self._input_kwargs:
+ self._input_kwargs["other"].set(x=self._input_kwargs["other_x"])
+ self._x = self._input_kwargs["x"]
+
+ a = Setter()
+ b = Setter()
+ a.set(x=1, other=b, other_x=2)
+ self.assertEqual(a._x, 1)
+ self.assertEqual(b._x, 2)
+
+
+class UtilTests(PySparkTestCase):
+ def test_py4j_exception_message(self):
+ from pyspark.util import _exception_message
+
+ with self.assertRaises(Py4JJavaError) as context:
+ # This attempts java.lang.String(null) which throws an NPE.
+ self.sc._jvm.java.lang.String(None)
+
+ self.assertTrue('NullPointerException' in _exception_message(context.exception))
+
+ def test_parsing_version_string(self):
+ from pyspark.util import VersionUtils
+ self.assertRaises(ValueError, lambda: VersionUtils.majorMinorVersion("abced"))
+
+
+if __name__ == "__main__":
+ from pyspark.tests.test_util import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
new file mode 100644
index 0000000000000..a33b77d983419
--- /dev/null
+++ b/python/pyspark/tests/test_worker.py
@@ -0,0 +1,157 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import sys
+import tempfile
+import threading
+import time
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest
+
+if sys.version_info[0] >= 3:
+ xrange = range
+
+
+class WorkerTests(ReusedPySparkTestCase):
+ def test_cancel_task(self):
+ temp = tempfile.NamedTemporaryFile(delete=True)
+ temp.close()
+ path = temp.name
+
+ def sleep(x):
+ import os
+ import time
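+ # record the daemon (parent) pid and the worker pid so the test can check on them after cancellation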
+ with open(path, 'w') as f:
+ f.write("%d %d" % (os.getppid(), os.getpid()))
+ time.sleep(100)
+
+ # start job in background thread
+ def run():
+ try:
+ self.sc.parallelize(range(1), 1).foreach(sleep)
+ except Exception:
+ pass
+ import threading
+ t = threading.Thread(target=run)
+ t.daemon = True
+ t.start()
+
+ daemon_pid, worker_pid = 0, 0
+ while True:
+ if os.path.exists(path):
+ with open(path) as f:
+ data = f.read().split(' ')
+ daemon_pid, worker_pid = map(int, data)
+ break
+ time.sleep(0.1)
+
+ # cancel jobs
+ self.sc.cancelAllJobs()
+ t.join()
+
+ for i in range(50):
+ try:
+ os.kill(worker_pid, 0)
+ time.sleep(0.1)
+ except OSError:
+ break # worker was killed
+ else:
+ self.fail("worker has not been killed after 5 seconds")
+
+ try:
+ os.kill(daemon_pid, 0)
+ except OSError:
+ self.fail("daemon had been killed")
+
+ # run a normal job
+ rdd = self.sc.parallelize(xrange(100), 1)
+ self.assertEqual(100, rdd.map(str).count())
+
+ def test_after_exception(self):
+ def raise_exception(_):
+ raise Exception()
+ rdd = self.sc.parallelize(xrange(100), 1)
+ with QuietTest(self.sc):
+ self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
+ self.assertEqual(100, rdd.map(str).count())
+
+ def test_after_jvm_exception(self):
+ tempFile = tempfile.NamedTemporaryFile(delete=False)
+ tempFile.write(b"Hello World!")
+ tempFile.close()
+ data = self.sc.textFile(tempFile.name, 1)
+ filtered_data = data.filter(lambda x: True)
+ self.assertEqual(1, filtered_data.count())
+ os.unlink(tempFile.name)
+ with QuietTest(self.sc):
+ self.assertRaises(Exception, lambda: filtered_data.count())
+
+ rdd = self.sc.parallelize(xrange(100), 1)
+ self.assertEqual(100, rdd.map(str).count())
+
+ def test_accumulator_when_reuse_worker(self):
+ from pyspark.accumulators import INT_ACCUMULATOR_PARAM
+ acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+ self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
+ self.assertEqual(sum(range(100)), acc1.value)
+
+ acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+ self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
+ self.assertEqual(sum(range(100)), acc2.value)
+ self.assertEqual(sum(range(100)), acc1.value)
+
+ def test_reuse_worker_after_take(self):
+ rdd = self.sc.parallelize(xrange(100000), 1)
+ self.assertEqual(0, rdd.first())
+
+ def count():
+ try:
+ rdd.count()
+ except Exception:
+ pass
+
+ t = threading.Thread(target=count)
+ t.daemon = True
+ t.start()
+ t.join(5)
+ self.assertTrue(not t.isAlive())
+ self.assertEqual(100000, rdd.count())
+
+ def test_with_different_versions_of_python(self):
+ rdd = self.sc.parallelize(range(10))
+ rdd.count()
+ version = self.sc.pythonVer
+ self.sc.pythonVer = "2.0"
+ try:
+ with QuietTest(self.sc):
+ self.assertRaises(Py4JJavaError, lambda: rdd.count())
+ finally:
+ self.sc.pythonVer = version
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.tests.test_worker import *
+
+ try:
+ import xmlrunner
+ testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/run-tests.py b/python/run-tests.py
index 9e6b7d780409a..be5a60a28def6 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -60,9 +60,7 @@ def print_red(text):
LOGGER = logging.getLogger()
# Find out where the assembly jars are located.
-# Later, add back 2.12 to this list:
-# for scala in ["2.11", "2.12"]:
-for scala in ["2.11"]:
+for scala in ["2.11", "2.12"]:
build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala)
if os.path.isdir(build_dir):
SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*")
@@ -263,8 +261,9 @@ def main():
for module in modules_to_test:
if python_implementation not in module.blacklisted_python_implementations:
for test_goal in module.python_test_goals:
- if test_goal in ('pyspark.streaming.tests', 'pyspark.mllib.tests',
- 'pyspark.tests', 'pyspark.sql.tests'):
+ heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
+ 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests']
+ if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)):
priority = 0
else:
priority = 100
diff --git a/repl/pom.xml b/repl/pom.xml
index fa015b69d45d4..c7de67e41ca94 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -20,12 +20,12 @@
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>org.apache.spark</groupId>
-    <artifactId>spark-parent_2.11</artifactId>
+    <artifactId>spark-parent_2.12</artifactId>
    <version>3.0.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
  </parent>

-  <artifactId>spark-repl_2.11</artifactId>
+  <artifactId>spark-repl_2.12</artifactId>
  <packaging>jar</packaging>
  <name>Spark Project REPL</name>
  <url>http://spark.apache.org/</url>
diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
index e5e2094368fb0..ac528ecb829b0 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -126,7 +126,7 @@ class ExecutorClassLoaderSuite
test("child first") {
val parentLoader = new URLClassLoader(urls2, null)
val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true)
- val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
+ val fakeClass = classLoader.loadClass("ReplFakeClass2").getConstructor().newInstance()
val fakeClassVersion = fakeClass.toString
assert(fakeClassVersion === "1")
}
@@ -134,7 +134,7 @@ class ExecutorClassLoaderSuite
test("parent first") {
val parentLoader = new URLClassLoader(urls2, null)
val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, false)
- val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance()
+ val fakeClass = classLoader.loadClass("ReplFakeClass1").getConstructor().newInstance()
val fakeClassVersion = fakeClass.toString
assert(fakeClassVersion === "2")
}
@@ -142,7 +142,7 @@ class ExecutorClassLoaderSuite
test("child first can fall back") {
val parentLoader = new URLClassLoader(urls2, null)
val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true)
- val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance()
+ val fakeClass = classLoader.loadClass("ReplFakeClass3").getConstructor().newInstance()
val fakeClassVersion = fakeClass.toString
assert(fakeClassVersion === "2")
}
@@ -151,7 +151,7 @@ class ExecutorClassLoaderSuite
val parentLoader = new URLClassLoader(urls2, null)
val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true)
intercept[java.lang.ClassNotFoundException] {
- classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance()
+ classLoader.loadClass("ReplFakeClassDoesNotExist").getConstructor().newInstance()
}
}
@@ -202,11 +202,11 @@ class ExecutorClassLoaderSuite
val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234",
getClass().getClassLoader(), false)
- val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
+ val fakeClass = classLoader.loadClass("ReplFakeClass2").getConstructor().newInstance()
val fakeClassVersion = fakeClass.toString
assert(fakeClassVersion === "1")
intercept[java.lang.ClassNotFoundException] {
- classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance()
+ classLoader.loadClass("ReplFakeClassDoesNotExist").getConstructor().newInstance()
}
}
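Context for the repeated `getConstructor().newInstance()` substitution in this suite: `Class.newInstance()` is deprecated since Java 9 because it propagates checked exceptions from the nullary constructor without declaring them, whereas `getConstructor().newInstance()` wraps them in `InvocationTargetException`. A minimal standalone sketch; the `java.util.ArrayList` target is arbitrary and only for illustration:

    // Both lines create an instance reflectively; the second form is the one the
    // patch migrates to, because it keeps checked exceptions visible to callers.
    val clazz = Class.forName("java.util.ArrayList")
    val viaDeprecatedApi = clazz.newInstance()                    // deprecated since Java 9
    val viaConstructor = clazz.getConstructor().newInstance()     // preferred replacement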
diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml
index b89ea383bf872..8d594ee8f1478 100644
--- a/resource-managers/kubernetes/core/pom.xml
+++ b/resource-managers/kubernetes/core/pom.xml
@@ -19,12 +19,12 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-parent_2.11</artifactId>
+ <artifactId>spark-parent_2.12</artifactId>
<version>3.0.0-SNAPSHOT</version>
<relativePath>../../../pom.xml</relativePath>
</parent>
- <artifactId>spark-kubernetes_2.11</artifactId>
+ <artifactId>spark-kubernetes_2.12</artifactId>
<packaging>jar</packaging>
<name>Spark Project Kubernetes</name>
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
index 3be93fd6059d0..43219cb9406f6 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
@@ -321,6 +321,7 @@ private[spark] object Config extends Logging {
val KUBERNETES_VOLUMES_PVC_TYPE = "persistentVolumeClaim"
val KUBERNETES_VOLUMES_EMPTYDIR_TYPE = "emptyDir"
val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path"
+ val KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY = "mount.subPath"
val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly"
val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path"
val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName"
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala
index b1762d1efe2ea..1a214fad96618 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala
@@ -34,5 +34,6 @@ private[spark] case class KubernetesEmptyDirVolumeConf(
private[spark] case class KubernetesVolumeSpec[T <: KubernetesVolumeSpecificConf](
volumeName: String,
mountPath: String,
+ mountSubPath: String,
mountReadOnly: Boolean,
volumeConf: T)
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala
index 713df5fffc3a2..155326469235b 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala
@@ -39,6 +39,7 @@ private[spark] object KubernetesVolumeUtils {
getVolumeTypesAndNames(properties).map { case (volumeType, volumeName) =>
val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY"
val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY"
+ val subPathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY"
for {
path <- properties.getTry(pathKey)
@@ -46,6 +47,7 @@ private[spark] object KubernetesVolumeUtils {
} yield KubernetesVolumeSpec(
volumeName = volumeName,
mountPath = path,
+ mountSubPath = properties.get(subPathKey).getOrElse(""),
mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean),
volumeConf = volumeConf
)
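For orientation, the new `mount.subPath` key parsed above sits next to the existing `mount.path` and `mount.readOnly` keys. A hedged sketch of what the user-facing configuration might look like, assuming the usual `spark.kubernetes.executor.volumes.` prefix; the volume name, claim name and paths are made up:

    import org.apache.spark.SparkConf

    // Mount the PVC "my-claim" at /data, but only its "checkpoints" subdirectory.
    val conf = new SparkConf(false)
      .set("spark.kubernetes.executor.volumes.persistentVolumeClaim.data.mount.path", "/data")
      .set("spark.kubernetes.executor.volumes.persistentVolumeClaim.data.mount.subPath", "checkpoints")
      .set("spark.kubernetes.executor.volumes.persistentVolumeClaim.data.mount.readOnly", "false")
      .set("spark.kubernetes.executor.volumes.persistentVolumeClaim.data.options.claimName", "my-claim")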
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
index e60259c4a9b5a..1473a7d3ee7f6 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
@@ -51,6 +51,7 @@ private[spark] class MountVolumesFeatureStep(
val volumeMount = new VolumeMountBuilder()
.withMountPath(spec.mountPath)
.withReadOnly(spec.mountReadOnly)
+ .withSubPath(spec.mountSubPath)
.withName(spec.volumeName)
.build()
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
index ad44dc63a90aa..3b5908f8f01d9 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
@@ -59,7 +59,7 @@ private[spark] class KubernetesDriverBuilder(
providePodTemplateConfigMapStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]
=> PodTemplateConfigMapStep) =
new PodTemplateConfigMapStep(_),
- provideInitialPod: () => SparkPod = SparkPod.initialPod) {
+ provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) {
import KubernetesDriverBuilder._
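A plausible reason for rewriting the default value above (an assumption on my part, not stated in the patch): under Scala 2.12, eta-expanding a zero-argument method such as `SparkPod.initialPod()` into a `() => SparkPod` is deprecated, so the explicit lambda spells out the intent. A tiny sketch with made-up names:

    def initialThing(): String = "initial"

    // Deprecated under Scala 2.12 (zero-argument eta-expansion):
    //   val provide: () => String = initialThing
    // Explicit lambda, as in the patch:
    val provide: () => String = () => initialThing()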
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
index e6b46308465d2..368d14bbc3cb5 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
@@ -57,7 +57,7 @@ private[spark] class KubernetesExecutorBuilder(
KubernetesConf[KubernetesExecutorSpecificConf]
=> HadoopSparkUserExecutorFeatureStep) =
new HadoopSparkUserExecutorFeatureStep(_),
- provideInitialPod: () => SparkPod = SparkPod.initialPod) {
+ provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) {
def buildFromFeatures(
kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = {
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala
index d795d159773a8..de79a58a3a756 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala
@@ -33,6 +33,18 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite {
KubernetesHostPathVolumeConf("/hostPath"))
}
+ test("Parses subPath correctly") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("test.emptyDir.volumeName.mount.path", "/path")
+ sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true")
+ sparkConf.set("test.emptyDir.volumeName.mount.subPath", "subPath")
+
+ val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head.get
+ assert(volumeSpec.volumeName === "volumeName")
+ assert(volumeSpec.mountPath === "/path")
+ assert(volumeSpec.mountSubPath === "subPath")
+ }
+
test("Parses persistentVolumeClaim volumes correctly") {
val sparkConf = new SparkConf(false)
sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path")
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala
index 82b828c04df69..9b5e42ebc849d 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala
@@ -44,6 +44,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite {
val volumeConf = KubernetesVolumeSpec(
"testVolume",
"/tmp",
+ "",
false,
KubernetesHostPathVolumeConf("/hostPath/tmp")
)
@@ -63,6 +64,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite {
val volumeConf = KubernetesVolumeSpec(
"testVolume",
"/tmp",
+ "",
true,
KubernetesPVCVolumeConf("pvcClaim")
)
@@ -84,6 +86,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite {
val volumeConf = KubernetesVolumeSpec(
"testVolume",
"/tmp",
+ "",
false,
KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G"))
)
@@ -105,6 +108,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite {
val volumeConf = KubernetesVolumeSpec(
"testVolume",
"/tmp",
+ "",
false,
KubernetesEmptyDirVolumeConf(None, None)
)
@@ -126,12 +130,14 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite {
val hpVolumeConf = KubernetesVolumeSpec(
"hpVolume",
"/tmp",
+ "",
false,
KubernetesHostPathVolumeConf("/hostPath/tmp")
)
val pvcVolumeConf = KubernetesVolumeSpec(
"checkpointVolume",
"/checkpoints",
+ "",
true,
KubernetesPVCVolumeConf("pvcClaim")
)
@@ -143,4 +149,77 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite {
assert(configuredPod.pod.getSpec.getVolumes.size() === 2)
assert(configuredPod.container.getVolumeMounts.size() === 2)
}
+
+ test("Mounts subpath on emptyDir") {
+ val volumeConf = KubernetesVolumeSpec(
+ "testVolume",
+ "/tmp",
+ "foo",
+ false,
+ KubernetesEmptyDirVolumeConf(None, None)
+ )
+ val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil)
+ val step = new MountVolumesFeatureStep(kubernetesConf)
+ val configuredPod = step.configurePod(SparkPod.initialPod())
+
+ assert(configuredPod.pod.getSpec.getVolumes.size() === 1)
+ val emptyDirMount = configuredPod.container.getVolumeMounts.get(0)
+ assert(emptyDirMount.getMountPath === "/tmp")
+ assert(emptyDirMount.getName === "testVolume")
+ assert(emptyDirMount.getSubPath === "foo")
+ }
+
+ test("Mounts subpath on persistentVolumeClaims") {
+ val volumeConf = KubernetesVolumeSpec(
+ "testVolume",
+ "/tmp",
+ "bar",
+ true,
+ KubernetesPVCVolumeConf("pvcClaim")
+ )
+ val kubernetesConf = emptyKubernetesConf.copy(roleVolumes = volumeConf :: Nil)
+ val step = new MountVolumesFeatureStep(kubernetesConf)
+ val configuredPod = step.configurePod(SparkPod.initialPod())
+
+ assert(configuredPod.pod.getSpec.getVolumes.size() === 1)
+ val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim
+ assert(pvcClaim.getClaimName === "pvcClaim")
+ assert(configuredPod.container.getVolumeMounts.size() === 1)
+ val pvcMount = configuredPod.container.getVolumeMounts.get(0)
+ assert(pvcMount.getMountPath === "/tmp")
+ assert(pvcMount.getName === "testVolume")
+ assert(pvcMount.getSubPath === "bar")
+ }
+
+ test("Mounts multiple subpaths") {
+ val volumeConf = KubernetesEmptyDirVolumeConf(None, None)
+ val emptyDirSpec = KubernetesVolumeSpec(
+ "testEmptyDir",
+ "/tmp/foo",
+ "foo",
+ true,
+ KubernetesEmptyDirVolumeConf(None, None)
+ )
+ val pvcSpec = KubernetesVolumeSpec(
+ "testPVC",
+ "/tmp/bar",
+ "bar",
+ true,
+ KubernetesEmptyDirVolumeConf(None, None)
+ )
+ val kubernetesConf = emptyKubernetesConf.copy(
+ roleVolumes = emptyDirSpec :: pvcSpec :: Nil)
+ val step = new MountVolumesFeatureStep(kubernetesConf)
+ val configuredPod = step.configurePod(SparkPod.initialPod())
+
+ assert(configuredPod.pod.getSpec.getVolumes.size() === 2)
+ val mounts = configuredPod.container.getVolumeMounts
+ assert(mounts.size() === 2)
+ assert(mounts.get(0).getName === "testEmptyDir")
+ assert(mounts.get(0).getMountPath === "/tmp/foo")
+ assert(mounts.get(0).getSubPath === "foo")
+ assert(mounts.get(1).getName === "testPVC")
+ assert(mounts.get(1).getMountPath === "/tmp/bar")
+ assert(mounts.get(1).getSubPath === "bar")
+ }
}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
index 0cfb2127a5d26..24aebef8633ad 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
@@ -177,6 +177,42 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite {
val volumeSpec = KubernetesVolumeSpec(
"volume",
"/tmp",
+ "",
+ false,
+ KubernetesHostPathVolumeConf("/path"))
+ val conf = KubernetesConf(
+ new SparkConf(false),
+ KubernetesDriverSpecificConf(
+ JavaMainAppResource(None),
+ "test-app",
+ "main",
+ Seq.empty),
+ "prefix",
+ "appId",
+ None,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ Map.empty,
+ volumeSpec :: Nil,
+ hadoopConfSpec = None)
+ validateStepTypesApplied(
+ builderUnderTest.buildFromFeatures(conf),
+ BASIC_STEP_TYPE,
+ CREDENTIALS_STEP_TYPE,
+ SERVICE_STEP_TYPE,
+ LOCAL_DIRS_STEP_TYPE,
+ LOCAL_FILES_STEP_TYPE,
+ MOUNT_VOLUMES_STEP_TYPE,
+ DRIVER_CMD_STEP_TYPE)
+ }
+
+ test("Apply volumes step if a mount subpath is present.") {
+ val volumeSpec = KubernetesVolumeSpec(
+ "volume",
+ "/tmp",
+ "foo",
false,
KubernetesHostPathVolumeConf("/path"))
val conf = KubernetesConf(
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
index 29007431d4b52..b3e75b6661926 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
@@ -135,6 +135,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite {
val volumeSpec = KubernetesVolumeSpec(
"volume",
"/tmp",
+ "",
false,
KubernetesHostPathVolumeConf("/checkpoint"))
val conf = KubernetesConf(
diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile
index a5850dad21d06..787b8cc52021e 100644
--- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile
+++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile
@@ -17,11 +17,6 @@
FROM openjdk:8-alpine
-ARG spark_jars=jars
-ARG example_jars=examples/jars
-ARG img_path=kubernetes/dockerfiles
-ARG k8s_tests=kubernetes/tests
-
# Before building the docker image, first build and make a Spark distribution following
# the instructions in http://spark.apache.org/docs/latest/building-spark.html.
# If this docker file is being used in the context of building your images from a Spark
@@ -39,10 +34,13 @@ RUN set -ex && \
ln -sv /bin/bash /bin/sh && \
chgrp root /etc/passwd && chmod ug+rw /etc/passwd
-COPY ${spark_jars} /opt/spark/jars
+COPY jars /opt/spark/jars
COPY bin /opt/spark/bin
COPY sbin /opt/spark/sbin
-COPY ${img_path}/spark/entrypoint.sh /opt/
+COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/
+COPY examples /opt/spark/examples
+COPY kubernetes/tests /opt/spark/tests
+COPY data /opt/spark/data
ENV SPARK_HOME /opt/spark
diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md
index 64f8e77597eba..73fc0581d64f5 100644
--- a/resource-managers/kubernetes/integration-tests/README.md
+++ b/resource-managers/kubernetes/integration-tests/README.md
@@ -107,7 +107,7 @@ properties to Maven. For example:
mvn integration-test -am -pl :spark-kubernetes-integration-tests_2.11 \
-Pkubernetes -Pkubernetes-integration-tests \
- -Phadoop-2.7 -Dhadoop.version=2.7.3 \
+ -Phadoop-2.7 -Dhadoop.version=2.7.4 \
-Dspark.kubernetes.test.sparkTgz=spark-3.0.0-SNAPSHOT-bin-example.tgz \
-Dspark.kubernetes.test.imageTag=sometag \
-Dspark.kubernetes.test.imageRepo=docker.io/somerepo \
diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml
index 301b6fe8eee56..17af0e03f2bbb 100644
--- a/resource-managers/kubernetes/integration-tests/pom.xml
+++ b/resource-managers/kubernetes/integration-tests/pom.xml
@@ -19,12 +19,12 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-parent_2.11</artifactId>
+ <artifactId>spark-parent_2.12</artifactId>
<version>3.0.0-SNAPSHOT</version>
<relativePath>../../../pom.xml</relativePath>
</parent>
- <artifactId>spark-kubernetes-integration-tests_2.11</artifactId>
+ <artifactId>spark-kubernetes-integration-tests_2.12</artifactId>
1.3.0
1.4.0
diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh
index a4a9f5b7da131..36e30d7b2cffb 100755
--- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh
+++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh
@@ -72,10 +72,16 @@ then
IMAGE_TAG=$(uuidgen);
cd $UNPACKED_SPARK_TGZ
+ # Build PySpark image
+ LANGUAGE_BINDING_BUILD_ARGS="-p $UNPACKED_SPARK_TGZ/kubernetes/dockerfiles/spark/bindings/python/Dockerfile"
+
+ # Build SparkR image
+ LANGUAGE_BINDING_BUILD_ARGS="$LANGUAGE_BINDING_BUILD_ARGS -R $UNPACKED_SPARK_TGZ/kubernetes/dockerfiles/spark/bindings/R/Dockerfile"
+
case $DEPLOY_MODE in
cloud)
# Build images
- $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build
+ $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build
# Push images appropriately
if [[ $IMAGE_REPO == gcr.io* ]] ;
@@ -89,13 +95,13 @@ then
docker-for-desktop)
# Only need to build as this will place it in our local Docker repo which is all
# we need for Docker for Desktop to work so no need to also push
- $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG build
+ $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build
;;
minikube)
# Only need to build and if we do this with the -m option for minikube we will
# build the images directly using the minikube Docker daemon so no need to push
- $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG build
+ $UNPACKED_SPARK_TGZ/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build
;;
*)
echo "Unrecognized deploy mode $DEPLOY_MODE" && exit 1
diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml
index 9585bdfafdcf4..7b3aad4d6ce35 100644
--- a/resource-managers/mesos/pom.xml
+++ b/resource-managers/mesos/pom.xml
@@ -19,12 +19,12 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-parent_2.11</artifactId>
+ <artifactId>spark-parent_2.12</artifactId>
<version>3.0.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
- <artifactId>spark-mesos_2.11</artifactId>
+ <artifactId>spark-mesos_2.12</artifactId>
<packaging>jar</packaging>
<name>Spark Project Mesos</name>
diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml
index e55b814be8465..d18df9955bb1f 100644
--- a/resource-managers/yarn/pom.xml
+++ b/resource-managers/yarn/pom.xml
@@ -19,12 +19,12 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-parent_2.11</artifactId>
+ <artifactId>spark-parent_2.12</artifactId>
<version>3.0.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
- <artifactId>spark-yarn_2.11</artifactId>
+ <artifactId>spark-yarn_2.12</artifactId>
<packaging>jar</packaging>
<name>Spark Project YARN</name>
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index ebdcf45603cea..9497530805c1a 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -20,7 +20,6 @@ package org.apache.spark.deploy.yarn
import java.util.Collections
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicInteger
-import java.util.regex.Pattern
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -598,13 +597,21 @@ private[yarn] class YarnAllocator(
(false, s"Container ${containerId}${onHostStr} was preempted.")
// Should probably still count memory exceeded exit codes towards task failures
case VMEM_EXCEEDED_EXIT_CODE =>
- (true, memLimitExceededLogMessage(
- completedContainer.getDiagnostics,
- VMEM_EXCEEDED_PATTERN))
+ val vmemExceededPattern = raw"$MEM_REGEX of $MEM_REGEX virtual memory used".r
+ val diag = vmemExceededPattern.findFirstIn(completedContainer.getDiagnostics)
+ .map(_.concat(".")).getOrElse("")
+ val message = "Container killed by YARN for exceeding virtual memory limits. " +
+ s"$diag Consider boosting ${EXECUTOR_MEMORY_OVERHEAD.key} or boosting " +
+ s"${YarnConfiguration.NM_VMEM_PMEM_RATIO} or disabling " +
+ s"${YarnConfiguration.NM_VMEM_CHECK_ENABLED} because of YARN-4714."
+ (true, message)
case PMEM_EXCEEDED_EXIT_CODE =>
- (true, memLimitExceededLogMessage(
- completedContainer.getDiagnostics,
- PMEM_EXCEEDED_PATTERN))
+ val pmemExceededPattern = raw"$MEM_REGEX of $MEM_REGEX physical memory used".r
+ val diag = pmemExceededPattern.findFirstIn(completedContainer.getDiagnostics)
+ .map(_.concat(".")).getOrElse("")
+ val message = "Container killed by YARN for exceeding physical memory limits. " +
+ s"$diag Consider boosting ${EXECUTOR_MEMORY_OVERHEAD.key}."
+ (true, message)
case _ =>
// all the failures which not covered above, like:
// disk failure, kill by app master or resource manager, ...
@@ -735,18 +742,6 @@ private[yarn] class YarnAllocator(
private object YarnAllocator {
val MEM_REGEX = "[0-9.]+ [KMG]B"
- val PMEM_EXCEEDED_PATTERN =
- Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used")
- val VMEM_EXCEEDED_PATTERN =
- Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used")
val VMEM_EXCEEDED_EXIT_CODE = -103
val PMEM_EXCEEDED_EXIT_CODE = -104
-
- def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = {
- val matcher = pattern.matcher(diagnostics)
- val diag = if (matcher.find()) " " + matcher.group() + "." else ""
- s"Container killed by YARN for exceeding memory limits. $diag " +
- "Consider boosting spark.yarn.executor.memoryOverhead or " +
- "disabling yarn.nodemanager.vmem-check-enabled because of YARN-4714."
- }
}
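A self-contained sketch of the diagnostics extraction that the new inline messages perform; `MEM_REGEX` and the sample diagnostics string come from this patch (the latter from the test removed further down), the rest is illustrative:

    val MEM_REGEX = "[0-9.]+ [KMG]B"
    val diagnostics =
      "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " +
        "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " +
        "5.8 GB of 4.2 GB virtual memory used. Killing container."

    val pmemExceededPattern = raw"$MEM_REGEX of $MEM_REGEX physical memory used".r
    val diag = pmemExceededPattern.findFirstIn(diagnostics).map(_.concat(".")).getOrElse("")
    // diag == "2.1 MB of 2 GB physical memory used."
    println(s"Container killed by YARN for exceeding physical memory limits. $diag")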
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala
index 4ed285230ff81..7d15f0e2fbac8 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala
@@ -107,7 +107,7 @@ private[spark] class SchedulerExtensionServices extends SchedulerExtensionServic
services = sparkContext.conf.get(SCHEDULER_SERVICES).map { sClass =>
val instance = Utils.classForName(sClass)
- .newInstance()
+ .getConstructor().newInstance()
.asInstanceOf[SchedulerExtensionService]
// bind this service
instance.start(binding)
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 35299166d9814..b61e7df4420ef 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -29,7 +29,6 @@ import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfterEach, Matchers}
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
-import org.apache.spark.deploy.yarn.YarnAllocator._
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.rpc.RpcEndpointRef
@@ -116,8 +115,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
}
def createContainer(host: String, resource: Resource = containerResource): Container = {
- // When YARN 2.6+ is required, avoid deprecation by using version with long second arg
- val containerId = ContainerId.newInstance(appAttemptId, containerNum)
+ val containerId = ContainerId.newContainerId(appAttemptId, containerNum)
containerNum += 1
val nodeId = NodeId.newInstance(host, 1000)
Container.newInstance(containerId, nodeId, "", resource, RM_REQUEST_PRIORITY, null)
@@ -377,17 +375,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava)
}
- test("memory exceeded diagnostic regexes") {
- val diagnostics =
- "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " +
- "beyond physical memory limits. Current usage: 2.1 MB of 2 GB physical memory used; " +
- "5.8 GB of 4.2 GB virtual memory used. Killing container."
- val vmemMsg = memLimitExceededLogMessage(diagnostics, VMEM_EXCEEDED_PATTERN)
- val pmemMsg = memLimitExceededLogMessage(diagnostics, PMEM_EXCEEDED_PATTERN)
- assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used."))
- assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used."))
- }
-
test("window based failure executor counting") {
sparkConf.set("spark.yarn.executor.failuresValidityInterval", "100s")
val handler = createAllocator(4)
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 36a73e3362218..4892819ae9973 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -240,6 +240,17 @@ This file is divided into 3 sections:
]]>
+
+ throw new \w+Error\(
+
+
+
JavaConversions
diff --git a/spark-docker-image-generator/src/test/resources/ExpectedDockerfile b/spark-docker-image-generator/src/test/resources/ExpectedDockerfile
index 2e0613cd2a826..31ec83d3db601 100644
--- a/spark-docker-image-generator/src/test/resources/ExpectedDockerfile
+++ b/spark-docker-image-generator/src/test/resources/ExpectedDockerfile
@@ -17,11 +17,6 @@
FROM fabric8/java-centos-openjdk8-jdk:latest
-ARG spark_jars=jars
-ARG example_jars=examples/jars
-ARG img_path=kubernetes/dockerfiles
-ARG k8s_tests=kubernetes/tests
-
# Before building the docker image, first build and make a Spark distribution following
# the instructions in http://spark.apache.org/docs/latest/building-spark.html.
# If this docker file is being used in the context of building your images from a Spark
@@ -39,10 +34,13 @@ RUN set -ex && \
ln -sv /bin/bash /bin/sh && \
chgrp root /etc/passwd && chmod ug+rw /etc/passwd
-COPY ${spark_jars} /opt/spark/jars
+COPY jars /opt/spark/jars
COPY bin /opt/spark/bin
COPY sbin /opt/spark/sbin
-COPY ${img_path}/spark/entrypoint.sh /opt/
+COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/
+COPY examples /opt/spark/examples
+COPY kubernetes/tests /opt/spark/tests
+COPY data /opt/spark/data
ENV SPARK_HOME /opt/spark
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 16ecebf159c1f..20cc5d03fbe52 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -21,12 +21,12 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-parent_2.11</artifactId>
+ <artifactId>spark-parent_2.12</artifactId>
<version>3.0.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
- <artifactId>spark-catalyst_2.11</artifactId>
+ <artifactId>spark-catalyst_2.12</artifactId>
<packaging>jar</packaging>
<name>Spark Project Catalyst</name>
<url>http://spark.apache.org/</url>
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index e2d34d1650ddc..5e732edb17baa 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -691,6 +691,7 @@ namedWindow
windowSpec
: name=identifier #windowRef
+ | '('name=identifier')' #windowRef
| '('
( CLUSTER BY partition+=expression (',' partition+=expression)*
| ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)?
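The extra `'(' name=identifier ')'` alternative lets a named window be referenced with redundant parentheses. A hedged illustration of the kind of statement this admits; the query itself is made up:

    // "OVER w" already parsed before this change; "OVER (w)" is what the new
    // windowRef alternative accepts.
    val query =
      """SELECT id, SUM(id) OVER (w)
        |FROM VALUES (1), (2) AS t(id)
        |WINDOW w AS (ORDER BY id)""".stripMargin
    // spark.sql(query) would resolve "(w)" to the named window "w".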
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java
index 2ce1fdcbf56ae..0258e66ffb6e5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/RowFactory.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
/**
@@ -25,7 +25,7 @@
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
public class RowFactory {
/**
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java
index 460513816dfd9..6344cf18c11b8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatch.java
@@ -20,6 +20,7 @@
import java.io.IOException;
import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.SparkOutOfMemoryError;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.memory.MemoryBlock;
@@ -126,7 +127,7 @@ public final void close() {
private boolean acquirePage(long requiredSize) {
try {
page = allocatePage(requiredSize);
- } catch (OutOfMemoryError e) {
+ } catch (SparkOutOfMemoryError e) {
logger.warn("Failed to allocate page ({} bytes).", requiredSize);
return false;
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 9002abdcfd474..d5f679fe23d48 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -334,17 +334,11 @@ public void setLong(int ordinal, long value) {
}
public void setFloat(int ordinal, float value) {
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
assertIndexIsValid(ordinal);
Platform.putFloat(baseObject, getElementOffset(ordinal, 4), value);
}
public void setDouble(int ordinal, double value) {
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
assertIndexIsValid(ordinal);
Platform.putDouble(baseObject, getElementOffset(ordinal, 8), value);
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index a76e6ef8c91c1..9bf9452855f5f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -224,9 +224,6 @@ public void setLong(int ordinal, long value) {
public void setDouble(int ordinal, double value) {
assertIndexIsValid(ordinal);
setNotNullAt(ordinal);
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
Platform.putDouble(baseObject, getFieldOffset(ordinal), value);
}
@@ -255,9 +252,6 @@ public void setByte(int ordinal, byte value) {
public void setFloat(int ordinal, float value) {
assertIndexIsValid(ordinal);
setNotNullAt(ordinal);
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
Platform.putFloat(baseObject, getFieldOffset(ordinal), value);
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 2781655002000..95263a0da95a8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -199,16 +199,10 @@ protected final void writeLong(long offset, long value) {
}
protected final void writeFloat(long offset, float value) {
- if (Float.isNaN(value)) {
- value = Float.NaN;
- }
Platform.putFloat(getBuffer(), offset, value);
}
protected final void writeDouble(long offset, double value) {
- if (Double.isNaN(value)) {
- value = Double.NaN;
- }
Platform.putDouble(getBuffer(), offset, value);
}
}
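Background on the three removals above (the rationale is inferred, not stated in the patch): the dropped branches canonicalized NaN values before writing them, presumably because distinct NaN bit patterns compare unequal when unsafe rows are compared byte-by-byte. A small sketch of that bit-level difference:

    // Two floats that are both NaN but carry different raw bit patterns.
    val canonicalNaN = java.lang.Float.NaN                       // bits 0x7fc00000
    val otherNaN = java.lang.Float.intBitsToFloat(0x7fc00001)    // a different quiet NaN
    assert(canonicalNaN.isNaN && otherNaN.isNaN)
    assert(java.lang.Float.floatToRawIntBits(canonicalNaN) !=
      java.lang.Float.floatToRawIntBits(otherNaN))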
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 1b2f5eee5ccdd..5395e4035e680 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -50,7 +50,7 @@ public final class UnsafeExternalRowSorter {
private long numRowsInserted = 0;
private final StructType schema;
- private final PrefixComputer prefixComputer;
+ private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
private final UnsafeExternalSorter sorter;
public abstract static class PrefixComputer {
@@ -74,7 +74,7 @@ public static UnsafeExternalRowSorter createWithRecordComparator(
StructType schema,
Supplier recordComparatorSupplier,
PrefixComparator prefixComparator,
- PrefixComputer prefixComputer,
+ UnsafeExternalRowSorter.PrefixComputer prefixComputer,
long pageSizeBytes,
boolean canUseRadixSort) throws IOException {
return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
@@ -85,7 +85,7 @@ public static UnsafeExternalRowSorter create(
StructType schema,
Ordering ordering,
PrefixComparator prefixComparator,
- PrefixComputer prefixComputer,
+ UnsafeExternalRowSorter.PrefixComputer prefixComputer,
long pageSizeBytes,
boolean canUseRadixSort) throws IOException {
Supplier recordComparatorSupplier =
@@ -98,9 +98,9 @@ private UnsafeExternalRowSorter(
StructType schema,
Supplier recordComparatorSupplier,
PrefixComparator prefixComparator,
- PrefixComputer prefixComputer,
+ UnsafeExternalRowSorter.PrefixComputer prefixComputer,
long pageSizeBytes,
- boolean canUseRadixSort) throws IOException {
+ boolean canUseRadixSort) {
this.schema = schema;
this.prefixComputer = prefixComputer;
final SparkEnv sparkEnv = SparkEnv.get();
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java
index 5f1032d1229da..5f6a46f2b8e89 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java
@@ -17,8 +17,8 @@
package org.apache.spark.sql.streaming;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.annotation.Experimental;
-import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.catalyst.plans.logical.*;
/**
@@ -29,7 +29,7 @@
* @since 2.2.0
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
public class GroupStateTimeout {
/**
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java
index 470c128ee6c3d..a3d72a1f5d49f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.streaming;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes;
/**
@@ -26,7 +26,7 @@
*
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
public class OutputMode {
/**
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
index 0f8570fe470bd..d786374f69e20 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
@@ -19,7 +19,7 @@
import java.util.*;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* To get/create specific data type, users should use singleton objects and factory methods
@@ -27,7 +27,7 @@
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
public class DataTypes {
/**
* Gets the StringType object.
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java
index 1290614a3207d..a54398324fc66 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java
@@ -20,7 +20,7 @@
import java.lang.annotation.*;
import org.apache.spark.annotation.DeveloperApi;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* ::DeveloperApi::
@@ -31,7 +31,7 @@
@DeveloperApi
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
-@InterfaceStability.Evolving
+@Evolving
public @interface SQLUserDefinedType {
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 50ee6cd4085ea..f5c87677ab9eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -17,16 +17,15 @@
package org.apache.spark.sql
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-
/**
* Thrown when a query fails to analyze, usually because the query itself is invalid.
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class AnalysisException protected[sql] (
val message: String,
val line: Option[Int] = None,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 7b02317b8538f..9853a4fcc2f9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -20,10 +20,9 @@ package org.apache.spark.sql
import scala.annotation.implicitNotFound
import scala.reflect.ClassTag
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.{Evolving, Experimental}
import org.apache.spark.sql.types._
-
/**
* :: Experimental ::
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -67,7 +66,7 @@ import org.apache.spark.sql.types._
* @since 1.6.0
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " +
"store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " +
"classes) are supported by importing spark.implicits._ Support for serializing other types " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index 8a30c81912fe9..42b865c027205 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -22,7 +22,7 @@ import java.lang.reflect.Modifier
import scala.reflect.{classTag, ClassTag}
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.{Evolving, Experimental}
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast}
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types._
* @since 1.6.0
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
object Encoders {
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 180c2d130074e..e12bf9616e2de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
import scala.util.hashing.MurmurHash3
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.StructType
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
object Row {
/**
* This method can be used to extract fields from a [[Row]] object in a pattern match. Example:
@@ -124,7 +124,7 @@ object Row {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
trait Row extends Serializable {
/** Number of elements in the Row. */
def size: Int = length
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 8ef8b2be6939c..311060e5961cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -73,10 +73,10 @@ object JavaTypeInference {
: (DataType, Boolean) = {
typeToken.getRawType match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
- (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
+ (c.getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance(), true)
case c: Class[_] if UDTRegistration.exists(c.getName) =>
- val udt = UDTRegistration.getUDTFor(c.getName).get.newInstance()
+ val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor().newInstance()
.asInstanceOf[UserDefinedType[_ >: Null]]
(udt, true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala
new file mode 100644
index 0000000000000..244081cd160b6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/QueryPlanningTracker.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.util.BoundedPriorityQueue
+
+
+/**
+ * A simple utility for tracking runtime and associated stats in query planning.
+ *
+ * There are two separate concepts we track:
+ *
+ * 1. Phases: These are broad-scope phases in query planning, as listed below, i.e. analysis,
+ *    optimization and physical planning (just planning).
+ *
+ * 2. Rules: These are the individual Catalyst rules that we track. In addition to time, we also
+ * track the number of invocations and effective invocations.
+ */
+object QueryPlanningTracker {
+
+ // Define a list of common phases here.
+ val PARSING = "parsing"
+ val ANALYSIS = "analysis"
+ val OPTIMIZATION = "optimization"
+ val PLANNING = "planning"
+
+ class RuleSummary(
+ var totalTimeNs: Long, var numInvocations: Long, var numEffectiveInvocations: Long) {
+
+ def this() = this(totalTimeNs = 0, numInvocations = 0, numEffectiveInvocations = 0)
+
+ override def toString: String = {
+ s"RuleSummary($totalTimeNs, $numInvocations, $numEffectiveInvocations)"
+ }
+ }
+
+ /**
+ * A thread local variable to implicitly pass the tracker around. This assumes the query planner
+ * is single-threaded, and avoids passing the same tracker context in every function call.
+ */
+ private val localTracker = new ThreadLocal[QueryPlanningTracker]() {
+ override def initialValue: QueryPlanningTracker = null
+ }
+
+ /** Returns the current tracker in scope, based on the thread local variable. */
+ def get: Option[QueryPlanningTracker] = Option(localTracker.get())
+
+ /** Sets the current tracker for the execution of function f. We assume f is single-threaded. */
+ def withTracker[T](tracker: QueryPlanningTracker)(f: => T): T = {
+ val originalTracker = localTracker.get()
+ localTracker.set(tracker)
+ try f finally { localTracker.set(originalTracker) }
+ }
+}
+
+
+class QueryPlanningTracker {
+
+ import QueryPlanningTracker._
+
+ // Mapping from the name of a rule to a rule's summary.
+ // Use a Java HashMap for less overhead.
+ private val rulesMap = new java.util.HashMap[String, RuleSummary]
+
+ // From a phase to time in ns.
+ private val phaseToTimeNs = new java.util.HashMap[String, Long]
+
+ /** Measure the runtime of function f, and add it to the time for the specified phase. */
+ def measureTime[T](phase: String)(f: => T): T = {
+ val startTime = System.nanoTime()
+ val ret = f
+ val timeTaken = System.nanoTime() - startTime
+ phaseToTimeNs.put(phase, phaseToTimeNs.getOrDefault(phase, 0) + timeTaken)
+ ret
+ }
+
+ /**
+ * Record a specific invocation of a rule.
+ *
+ * @param rule name of the rule
+ * @param timeNs time taken to run this invocation
+ * @param effective whether the invocation has resulted in a plan change
+ */
+ def recordRuleInvocation(rule: String, timeNs: Long, effective: Boolean): Unit = {
+ var s = rulesMap.get(rule)
+ if (s eq null) {
+ s = new RuleSummary
+ rulesMap.put(rule, s)
+ }
+
+ s.totalTimeNs += timeNs
+ s.numInvocations += 1
+ s.numEffectiveInvocations += (if (effective) 1 else 0)
+ }
+
+ // ------------ reporting functions below ------------
+
+ def rules: Map[String, RuleSummary] = rulesMap.asScala.toMap
+
+ def phases: Map[String, Long] = phaseToTimeNs.asScala.toMap
+
+ /**
+ * Returns the top k most expensive rules (as measured by time). If k is larger than the number of
+ * rules seen so far, all rules are returned. If no rule has been seen yet, or k <= 0, an empty
+ * sequence is returned.
+ */
+ def topRulesByTime(k: Int): Seq[(String, RuleSummary)] = {
+ if (k <= 0) {
+ Seq.empty
+ } else {
+ val orderingByTime: Ordering[(String, RuleSummary)] = Ordering.by(e => e._2.totalTimeNs)
+ val q = new BoundedPriorityQueue(k)(orderingByTime)
+ rulesMap.asScala.foreach(q.+=)
+ q.toSeq.sortBy(r => -r._2.totalTimeNs)
+ }
+ }
+
+}
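A minimal usage sketch of the tracker added above, assuming only the API defined in this file; the driver code around it is illustrative:

    import org.apache.spark.sql.catalyst.QueryPlanningTracker

    val tracker = new QueryPlanningTracker
    QueryPlanningTracker.withTracker(tracker) {
      tracker.measureTime(QueryPlanningTracker.ANALYSIS) {
        // analysis work would run here, e.g. analyzer.executeAndCheck(plan, tracker)
      }
      tracker.recordRuleInvocation("MyRule", timeNs = 1000L, effective = true)
    }

    // Report the recorded phases and the most expensive rules.
    tracker.phases.foreach { case (phase, ns) => println(s"$phase took ${ns / 1e6} ms") }
    tracker.topRulesByTime(3).foreach { case (rule, summary) => println(s"$rule -> $summary") }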
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 912744eab6a3a..c8542d0f2f7de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -357,7 +357,8 @@ object ScalaReflection extends ScalaReflection {
)
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
- val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
+ getConstructor().newInstance()
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
@@ -365,8 +366,8 @@ object ScalaReflection extends ScalaReflection {
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
- .asInstanceOf[UserDefinedType[_]]
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
+ newInstance().asInstanceOf[UserDefinedType[_]]
val obj = NewInstance(
udt.getClass,
Nil,
@@ -601,7 +602,7 @@ object ScalaReflection extends ScalaReflection {
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t)
- .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ .getAnnotation(classOf[SQLUserDefinedType]).udt().getConstructor().newInstance()
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
@@ -609,8 +610,8 @@ object ScalaReflection extends ScalaReflection {
Invoke(obj, "serialize", udt, inputObject :: Nil)
case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
- .asInstanceOf[UserDefinedType[_]]
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
+ newInstance().asInstanceOf[UserDefinedType[_]]
val obj = NewInstance(
udt.getClass,
Nil,
@@ -721,11 +722,12 @@ object ScalaReflection extends ScalaReflection {
// Null type would wrongly match the first of them, which is Option as of now
case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true)
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
- val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
+ getConstructor().newInstance()
Schema(udt, nullable = true)
case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
- .asInstanceOf[UserDefinedType[_]]
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
+ newInstance().asInstanceOf[UserDefinedType[_]]
Schema(udt, nullable = true)
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
@@ -786,12 +788,37 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * Finds an accessible constructor with compatible parameters. This is a more flexible search
- * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
- * matching constructor is returned. Otherwise, it returns `None`.
+ * Finds an accessible constructor with compatible parameters. This is a more flexible search than
+ * the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
+ * matching constructor is returned if it exists. Otherwise, we check for compatible constructors
+ * defined in the companion object as `apply` methods. If neither exists, it returns `None`.
*/
- def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = {
- Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*))
+ def findConstructor[T](cls: Class[T], paramTypes: Seq[Class[_]]): Option[Seq[AnyRef] => T] = {
+ Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) match {
+ case Some(c) => Some(x => c.newInstance(x: _*).asInstanceOf[T])
+ case None =>
+ val companion = mirror.staticClass(cls.getName).companion
+ val moduleMirror = mirror.reflectModule(companion.asModule)
+ val applyMethods = companion.asTerm.typeSignature
+ .member(universe.TermName("apply")).asTerm.alternatives
+ applyMethods.find { method =>
+ val params = method.typeSignature.paramLists.head
+ // Check that the needed params are the same length and of matching types
+ params.size == paramTypes.tail.size &&
+ params.zip(paramTypes.tail).forall { case(ps, pc) =>
+ ps.typeSignature.typeSymbol == mirror.classSymbol(pc)
+ }
+ }.map { applyMethodSymbol =>
+ val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size
+ val instanceMirror = mirror.reflect(moduleMirror.instance)
+ val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod)
+ (_args: Seq[AnyRef]) => {
+ // Drop the "outer" argument if it is provided
+ val args = if (_args.size == expectedArgsCount) _args else _args.tail
+ method.apply(args: _*).asInstanceOf[T]
+ }
+ }
+ }
}
/**
@@ -971,8 +998,19 @@ trait ScalaReflection extends Logging {
}
}
+ /**
+ * If our type is a Scala trait it may have a companion object that
+ * only defines a constructor via `apply` method.
+ */
+ private def getCompanionConstructor(tpe: Type): Symbol = {
+ tpe.typeSymbol.asClass.companion.asTerm.typeSignature.member(universe.TermName("apply"))
+ }
+
protected def constructParams(tpe: Type): Seq[Symbol] = {
- val constructorSymbol = tpe.dealias.member(termNames.CONSTRUCTOR)
+ val constructorSymbol = tpe.member(termNames.CONSTRUCTOR) match {
+ case NoSymbol => getCompanionConstructor(tpe)
+ case sym => sym
+ }
val params = if (constructorSymbol.isMethod) {
constructorSymbol.asMethod.paramLists
} else {
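To make the intent of the companion-object fallback concrete: the shape it targets is a trait whose only public constructor is an `apply` on its companion. The names below are invented for illustration:

    trait TraitExample {
      def a: Int
      def b: String
    }

    object TraitExample {
      private case class Impl(a: Int, b: String) extends TraitExample
      // The companion `apply` acts as the constructor that findConstructor can now discover.
      def apply(a: Int, b: String): TraitExample = Impl(a, b)
    }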
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c2d22c5e7ce60..b977fa07db5c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -102,16 +102,18 @@ class Analyzer(
this(catalog, conf, conf.optimizerMaxIterations)
}
- def executeAndCheck(plan: LogicalPlan): LogicalPlan = AnalysisHelper.markInAnalyzer {
- val analyzed = execute(plan)
- try {
- checkAnalysis(analyzed)
- analyzed
- } catch {
- case e: AnalysisException =>
- val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed))
- ae.setStackTrace(e.getStackTrace)
- throw ae
+ def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = {
+ AnalysisHelper.markInAnalyzer {
+ val analyzed = executeAndTrack(plan, tracker)
+ try {
+ checkAnalysis(analyzed)
+ analyzed
+ } catch {
+ case e: AnalysisException =>
+ val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed))
+ ae.setStackTrace(e.getStackTrace)
+ throw ae
+ }
}
}
@@ -824,7 +826,8 @@ class Analyzer(
}
private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
- attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier)
+ val exprId = attrMap.getOrElse(attr, attr).exprId
+ attr.withExprId(exprId)
}
/**
@@ -870,7 +873,7 @@ class Analyzer(
private def dedupOuterReferencesInSubquery(
plan: LogicalPlan,
attrMap: AttributeMap[Attribute]): LogicalPlan = {
- plan resolveOperatorsDown { case currentFragment =>
+ plan transformDown { case currentFragment =>
currentFragment transformExpressions {
case OuterReference(a: Attribute) =>
OuterReference(dedupAttr(a, attrMap))
@@ -952,6 +955,12 @@ class Analyzer(
// rule: ResolveDeserializer.
case plan if containsDeserializer(plan.expressions) => plan
+ // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of
+ // `AppendColumns`, because `AppendColumns`'s serializer might produce conflicting attribute
+ // names, leading to an ambiguous-reference exception.
+ case a @ Aggregate(groupingExprs, aggExprs, appendColumns: AppendColumns) =>
+ a.mapExpressions(resolve(_, appendColumns))
+
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q.mapExpressions(resolve(_, q))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 72ac80e0a0a18..133fa119b7aa6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -181,8 +181,9 @@ object TypeCoercion {
}
/**
- * The method finds a common type for data types that differ only in nullable, containsNull
- * and valueContainsNull flags. If the input types are too different, None is returned.
+ * The method finds a common type for data types that differ only in nullable flags, including
+ * `nullable`, `containsNull` of [[ArrayType]] and `valueContainsNull` of [[MapType]].
+ * If the input types are different besides nullable flags, None is returned.
*/
def findCommonTypeDifferentOnlyInNullFlags(t1: DataType, t2: DataType): Option[DataType] = {
if (t1 == t2) {
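As a concrete reading of the reworded doc comment above (the expected results are inferred from that description, not verified here):

    import org.apache.spark.sql.catalyst.analysis.TypeCoercion
    import org.apache.spark.sql.types._

    val left = ArrayType(IntegerType, containsNull = false)
    val right = ArrayType(IntegerType, containsNull = true)
    // The two types differ only in containsNull, so a common type exists; presumably
    // the more permissive ArrayType(IntegerType, containsNull = true).
    val common = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(left, right)
    // IntegerType vs StringType differ in more than nullability flags, so this is None.
    val none = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(IntegerType, StringType)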
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 857cf382b8f2c..36cad3cf74785 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -112,6 +112,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
override def withQualifier(newQualifier: Seq[String]): UnresolvedAttribute = this
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
override def withMetadata(newMetadata: Metadata): Attribute = this
+ override def withExprId(newExprId: ExprId): UnresolvedAttribute = this
override def toString: String = s"'$name"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index c11b444212946..b6771ec4dffe9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -1134,7 +1134,8 @@ class SessionCatalog(
if (clsForUDAF.isAssignableFrom(clazz)) {
val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF")
val e = cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int])
- .newInstance(input, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1))
+ .newInstance(input,
+ clazz.getConstructor().newInstance().asInstanceOf[Object], Int.box(1), Int.box(1))
.asInstanceOf[ImplicitCastInputTypes]
// Check input argument size
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
index cdaaa172e8367..94bdb72d675d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
@@ -25,6 +25,7 @@ import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
class CSVOptions(
@transient val parameters: CaseInsensitiveMap[String],
@@ -33,11 +34,22 @@ class CSVOptions(
defaultColumnNameOfCorruptRecord: String)
extends Logging with Serializable {
+ def this(
+ parameters: Map[String, String],
+ columnPruning: Boolean,
+ defaultTimeZoneId: String) = {
+ this(
+ CaseInsensitiveMap(parameters),
+ columnPruning,
+ defaultTimeZoneId,
+ SQLConf.get.columnNameOfCorruptRecord)
+ }
+
def this(
parameters: Map[String, String],
columnPruning: Boolean,
defaultTimeZoneId: String,
- defaultColumnNameOfCorruptRecord: String = "") = {
+ defaultColumnNameOfCorruptRecord: String) = {
this(
CaseInsensitiveMap(parameters),
columnPruning,
@@ -131,13 +143,16 @@ class CSVOptions(
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
+ // A language tag in IETF BCP 47 format
+ val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)
+
// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
val dateFormat: FastDateFormat =
- FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
+ FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale)
val timestampFormat: FastDateFormat =
FastDateFormat.getInstance(
- parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
+ parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale)
val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
@@ -177,6 +192,20 @@ class CSVOptions(
*/
val emptyValueInWrite = emptyValue.getOrElse("\"\"")
+ /**
+ * A string between two consecutive CSV records.
+ */
+ val lineSeparator: Option[String] = parameters.get("lineSep").map { sep =>
+ require(sep.nonEmpty, "'lineSep' cannot be an empty string.")
+ require(sep.length == 1, "'lineSep' can contain only 1 character.")
+ sep
+ }
+
+ val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
+ lineSep.getBytes(charset)
+ }
+ val lineSeparatorInWrite: Option[String] = lineSeparator
+
def asWriterSettings: CsvWriterSettings = {
val writerSettings = new CsvWriterSettings()
val format = writerSettings.getFormat
@@ -185,6 +214,8 @@ class CSVOptions(
format.setQuoteEscape(escape)
charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping)
format.setComment(comment)
+ lineSeparatorInWrite.foreach(format.setLineSeparator)
+
writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
writerSettings.setNullValue(nullValue)
@@ -201,8 +232,10 @@ class CSVOptions(
format.setDelimiter(delimiter)
format.setQuote(quote)
format.setQuoteEscape(escape)
+ lineSeparator.foreach(format.setLineSeparator)
charToEscapeQuoteEscaping.foreach(format.setCharToEscapeQuoteEscaping)
format.setComment(comment)
+
settings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceInRead)
settings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceInRead)
settings.setReadInputOnSeparateThread(false)
@@ -212,7 +245,10 @@ class CSVOptions(
settings.setEmptyValue(emptyValueInRead)
settings.setMaxCharsPerColumn(maxCharsPerColumn)
settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
- settings.setLineSeparatorDetectionEnabled(multiLine == true)
+ settings.setLineSeparatorDetectionEnabled(lineSeparatorInRead.isEmpty && multiLine)
+ lineSeparatorInRead.foreach { _ =>
+ settings.setNormalizeLineEndingsWithinQuotes(!multiLine)
+ }
settings
}
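The new `locale` and `lineSep` options are read straight from the option map above; a hedged sketch of how they would typically be passed in from the reader side (the file path and separator are hypothetical):

```scala
import org.apache.spark.sql.SparkSession

object CsvOptionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("csv-options-sketch")
      .getOrCreate()

    val df = spark.read
      .option("header", "true")
      .option("lineSep", "\u0001")     // must be exactly one character per the checks above
      .option("locale", "de-DE")       // IETF BCP 47 tag used for dateFormat/timestampFormat
      .option("dateFormat", "d MMM yyyy")
      .csv("/tmp/sample.csv")          // hypothetical input path

    df.show(truncate = false)
    spark.stop()
  }
}
```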
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
index 46ed58ed92830..ed196935e357f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
@@ -271,11 +271,12 @@ private[sql] object UnivocityParser {
def tokenizeStream(
inputStream: InputStream,
shouldDropHeader: Boolean,
- tokenizer: CsvParser): Iterator[Array[String]] = {
+ tokenizer: CsvParser,
+ encoding: String): Iterator[Array[String]] = {
val handleHeader: () => Unit =
() => if (shouldDropHeader) tokenizer.parseNext
- convertStream(inputStream, tokenizer, handleHeader)(tokens => tokens)
+ convertStream(inputStream, tokenizer, handleHeader, encoding)(tokens => tokens)
}
/**
@@ -297,7 +298,7 @@ private[sql] object UnivocityParser {
val handleHeader: () => Unit =
() => headerChecker.checkHeaderColumnNames(tokenizer)
- convertStream(inputStream, tokenizer, handleHeader) { tokens =>
+ convertStream(inputStream, tokenizer, handleHeader, parser.options.charset) { tokens =>
safeParser.parse(tokens)
}.flatten
}
@@ -305,9 +306,10 @@ private[sql] object UnivocityParser {
private def convertStream[T](
inputStream: InputStream,
tokenizer: CsvParser,
- handleHeader: () => Unit)(
+ handleHeader: () => Unit,
+ encoding: String)(
convert: Array[String] => T) = new Iterator[T] {
- tokenizer.beginParsing(inputStream)
+ tokenizer.beginParsing(inputStream, encoding)
// We can handle header here since here the stream is open.
handleHeader()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 592520c59a761..589e215c55e44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -49,15 +49,6 @@ object ExpressionEncoder {
val mirror = ScalaReflection.mirror
val tpe = typeTag[T].in(mirror).tpe
- if (ScalaReflection.optionOfProductType(tpe)) {
- throw new UnsupportedOperationException(
- "Cannot create encoder for Option of Product type, because Product type is represented " +
- "as a row, and the entire row can not be null in Spark SQL like normal databases. " +
- "You can wrap your type with Tuple1 if you do want top level null Product objects, " +
- "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " +
- "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`")
- }
-
val cls = mirror.runtimeClass(tpe)
val serializer = ScalaReflection.serializerForType(tpe)
val deserializer = ScalaReflection.deserializerForType(tpe)
@@ -198,7 +189,7 @@ case class ExpressionEncoder[T](
val serializer: Seq[NamedExpression] = {
val clsName = Utils.getSimpleName(clsTag.runtimeClass)
- if (isSerializedAsStruct) {
+ if (isSerializedAsStructForTopLevel) {
val nullSafeSerializer = objSerializer.transformUp {
case r: BoundReference =>
// For input object of Product type, we can't encode it to row if it's null, as Spark SQL
@@ -213,6 +204,9 @@ case class ExpressionEncoder[T](
} else {
// For other input objects like primitive, array, map, etc., we construct a struct to wrap
// the serializer which is a column of an row.
+ //
+ // Note: Because Spark SQL doesn't allow a top-level row to be null, a top-level
+ // Option[Product] type is encoded as a single struct column instead of being flattened.
CreateNamedStruct(Literal("value") :: objSerializer :: Nil)
}
}.flatten
@@ -226,7 +220,7 @@ case class ExpressionEncoder[T](
* `GetColumnByOrdinal` with corresponding ordinal.
*/
val deserializer: Expression = {
- if (isSerializedAsStruct) {
+ if (isSerializedAsStructForTopLevel) {
// We serialized this kind of objects to root-level row. The input of general deserializer
// is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to
// transform attributes accessors.
@@ -253,10 +247,21 @@ case class ExpressionEncoder[T](
})
/**
- * Returns true if the type `T` is serialized as a struct.
+ * Returns true if the type `T` is serialized as a struct by `objSerializer`.
*/
def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType]
+ /**
+ * If the type `T` is serialized as a struct, fields of the struct are naturally mapped to
+ * top-level columns when it is encoded to a Spark SQL row; in other words, the serialized
+ * struct is flattened into the row. But if `T` is also an `Option` type, it can't be
+ * flattened into the top-level row, because a top-level row can't be null in Spark SQL.
+ * This method returns true if `T` is serialized as a struct and is not an `Option` type.
+ */
+ def isSerializedAsStructForTopLevel: Boolean = {
+ isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)
+ }
+
// serializer expressions are used to encode an object to a row, while the object is usually an
// intermediate value produced inside an operator, not from the output of the child operator. This
// is quite different from normal expressions, and `AttributeReference` doesn't work here
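With the Option-of-Product guard removed and the top-level struct handling above, a `Dataset[Option[T]]` of a case class can now be encoded; a minimal, hedged sketch (class and session names are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object OptionOfProductSketch {
  case class Person(name: String, age: Int)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("option-encoder").getOrCreate()
    import spark.implicits._

    // Option[Product] is serialized as a single nullable struct column ("value")
    // rather than being flattened into top-level columns, so None maps to a null struct.
    val ds = Seq(Some(Person("alice", 30)), None).toDS()
    ds.printSchema()
    ds.show()

    spark.stop()
  }
}
```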
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 040b56cc1caea..89e9071324eff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -67,4 +67,20 @@ object ExprUtils {
case _ =>
throw new AnalysisException("Must use a map() function for options")
}
+
+ /**
+ * A convenient function for schema validation in datasources supporting
+ * `columnNameOfCorruptRecord` as an option.
+ */
+ def verifyColumnNameOfCorruptRecord(
+ schema: StructType,
+ columnNameOfCorruptRecord: String): Unit = {
+ schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+ val f = schema(corruptFieldIndex)
+ if (f.dataType != StringType || !f.nullable) {
+ throw new AnalysisException(
+ "The field for corrupt records must be string type and nullable")
+ }
+ }
+ }
}
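A small sketch of what the helper enforces, assuming `ExprUtils` stays accessible from user code (the schemas are illustrative):

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object CorruptRecordSchemaSketch {
  def main(args: Array[String]): Unit = {
    // OK: the corrupt-record field is a nullable string.
    val ok = StructType(Seq(
      StructField("id", IntegerType),
      StructField("_corrupt_record", StringType, nullable = true)))
    ExprUtils.verifyColumnNameOfCorruptRecord(ok, "_corrupt_record")

    // Expected to throw AnalysisException: the corrupt-record field is not a string.
    val bad = StructType(Seq(
      StructField("id", IntegerType),
      StructField("_corrupt_record", IntegerType)))
    try {
      ExprUtils.verifyColumnNameOfCorruptRecord(bad, "_corrupt_record")
    } catch {
      case e: AnalysisException => println(s"rejected: ${e.getMessage}")
    }
  }
}
```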
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 141fcffcb6fab..2ecec61adb0ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines the basic expression abstract classes in Catalyst.
@@ -40,12 +41,28 @@ import org.apache.spark.util.Utils
* "name(arguments...)", the concrete implementation must be a case class whose constructor
* arguments are all Expressions types. See [[Substring]] for an example.
*
- * There are a few important traits:
+ * There are a few important traits or abstract classes:
*
* - [[Nondeterministic]]: an expression that is not deterministic.
+ * - [[Stateful]]: an expression that contains mutable state. For example, MonotonicallyIncreasingID
+ * and Rand. A stateful expression is always non-deterministic.
* - [[Unevaluable]]: an expression that is not supposed to be evaluated.
* - [[CodegenFallback]]: an expression that does not have code gen implemented and falls back to
* interpreted mode.
+ * - [[NullIntolerant]]: an expression that is null intolerant (i.e. any null input will result in
+ * null output).
+ * - [[NonSQLExpression]]: a common base trait for the expressions that do not have a SQL-like
+ * representation, e.g. `ScalaUDF`, `ScalaUDAF`, and the object
+ * expressions `MapObjects` and `Invoke`.
+ * - [[UserDefinedExpression]]: a common base trait for user-defined functions, including
+ * UDF/UDAF/UDTF.
+ * - [[HigherOrderFunction]]: a common base trait for higher order functions that take one or more
+ * (lambda) functions and apply them to some objects. The function
+ * produces a number of variables which can be consumed by some lambda
+ * functions.
+ * - [[NamedExpression]]: An [[Expression]] that is named.
+ * - [[TimeZoneAwareExpression]]: A common base trait for time zone aware expressions.
+ * - [[SubqueryExpression]]: A base interface for expressions that contain a [[LogicalPlan]].
*
* - [[LeafExpression]]: an expression that has no child.
* - [[UnaryExpression]]: an expression that has one child.
@@ -54,12 +71,20 @@ import org.apache.spark.util.Utils
* - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have
* the same output data type.
*
+ * A few important traits used for type coercion rules:
+ * - [[ExpectsInputTypes]]: an expression that has the expected input types. This trait is typically
+ * used by operator expressions (e.g. [[Add]], [[Subtract]]) to define
+ * expected input types without any implicit casting.
+ * - [[ImplicitCastInputTypes]]: an expression that has the expected input types, which can be
+ * implicitly cast using [[TypeCoercion.ImplicitTypeCasts]].
+ * - [[ComplexTypeMergingExpression]]: to resolve output types of the complex expressions
+ * (e.g., [[CaseWhen]]).
*/
abstract class Expression extends TreeNode[Expression] {
/**
* Returns true when an expression is a candidate for static evaluation before the query is
- * executed.
+ * executed. A typical use case: [[org.apache.spark.sql.catalyst.optimizer.ConstantFolding]]
*
* The following conditions are used to determine suitability for constant folding:
* - A [[Coalesce]] is foldable if all of its children are foldable
@@ -72,7 +97,8 @@ abstract class Expression extends TreeNode[Expression] {
/**
* Returns true when the current expression always return the same result for fixed inputs from
- * children.
+ * children. Non-deterministic expressions should not change in number or order, and they
+ * should not be evaluated during query planning.
*
* Note that this means that an expression should be considered as non-deterministic if:
* - it relies on some mutable internal state, or
@@ -237,7 +263,7 @@ abstract class Expression extends TreeNode[Expression] {
override def simpleString: String = toString
- override def toString: String = prettyName + Utils.truncatedString(
+ override def toString: String = prettyName + truncatedString(
flatArguments.toSeq, "(", ", ", ")")
/**
@@ -252,8 +278,9 @@ abstract class Expression extends TreeNode[Expression] {
/**
- * An expression that cannot be evaluated. Some expressions don't live past analysis or optimization
- * time (e.g. Star). This trait is used by those expressions.
+ * An expression that cannot be evaluated. These expressions don't live past analysis or
+ * optimization time (e.g. Star) and should not be evaluated during query planning and
+ * execution.
*/
trait Unevaluable extends Expression {
@@ -724,9 +751,10 @@ abstract class TernaryExpression extends Expression {
}
/**
- * A trait resolving nullable, containsNull, valueContainsNull flags of the output date type.
- * This logic is usually utilized by expressions combining data from multiple child expressions
- * of non-primitive types (e.g. [[CaseWhen]]).
+ * A trait used for resolving nullable flags, including `nullable`, `containsNull` of [[ArrayType]]
+ * and `valueContainsNull` of [[MapType]], of the output data type. This is usually utilized by
+ * the expressions (e.g. [[CaseWhen]]) that combine data from multiple child expressions of
+ * non-primitive types.
*/
trait ComplexTypeMergingExpression extends Expression {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index e1d16a2cd38b0..56c2ee6b53fe5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -128,12 +128,10 @@ case class AggregateExpression(
override def nullable: Boolean = aggregateFunction.nullable
override def references: AttributeSet = {
- val childReferences = mode match {
- case Partial | Complete => aggregateFunction.references.toSeq
- case PartialMerge | Final => aggregateFunction.aggBufferAttributes
+ mode match {
+ case Partial | Complete => aggregateFunction.references
+ case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes)
}
-
- AttributeSet(childReferences)
}
override def toString: String = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index b868a0f4fa284..7c8f7cd4315b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1305,7 +1305,7 @@ object CodeGenerator extends Logging {
throw new CompileException(msg, e.getLocation)
}
- (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize)
+ (evaluator.getClazz().getConstructor().newInstance().asInstanceOf[GeneratedClass], maxCodeSize)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 9a51be6ed5aeb..283fd2a6e9383 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -68,62 +68,55 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
genComparisons(ctx, ordering)
}
+ /**
+ * Creates the variables for ordering based on the given order.
+ */
+ private def createOrderKeys(
+ ctx: CodegenContext,
+ row: String,
+ ordering: Seq[SortOrder]): Seq[ExprCode] = {
+ ctx.INPUT_ROW = row
+ // to use INPUT_ROW we must make sure currentVars is null
+ ctx.currentVars = null
+ ordering.map(_.child.genCode(ctx))
+ }
+
/**
* Generates the code for ordering based on the given order.
*/
def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = {
val oldInputRow = ctx.INPUT_ROW
val oldCurrentVars = ctx.currentVars
- val inputRow = "i"
- ctx.INPUT_ROW = inputRow
- // to use INPUT_ROW we must make sure currentVars is null
- ctx.currentVars = null
-
- val comparisons = ordering.map { order =>
- val eval = order.child.genCode(ctx)
- val asc = order.isAscending
- val isNullA = ctx.freshName("isNullA")
- val primitiveA = ctx.freshName("primitiveA")
- val isNullB = ctx.freshName("isNullB")
- val primitiveB = ctx.freshName("primitiveB")
+ val rowAKeys = createOrderKeys(ctx, "a", ordering)
+ val rowBKeys = createOrderKeys(ctx, "b", ordering)
+ val comparisons = rowAKeys.zip(rowBKeys).zipWithIndex.map { case ((l, r), i) =>
+ val dt = ordering(i).child.dataType
+ val asc = ordering(i).isAscending
+ val nullOrdering = ordering(i).nullOrdering
+ val lRetValue = nullOrdering match {
+ case NullsFirst => "-1"
+ case NullsLast => "1"
+ }
+ val rRetValue = nullOrdering match {
+ case NullsFirst => "1"
+ case NullsLast => "-1"
+ }
s"""
- ${ctx.INPUT_ROW} = a;
- boolean $isNullA;
- ${CodeGenerator.javaType(order.child.dataType)} $primitiveA;
- {
- ${eval.code}
- $isNullA = ${eval.isNull};
- $primitiveA = ${eval.value};
- }
- ${ctx.INPUT_ROW} = b;
- boolean $isNullB;
- ${CodeGenerator.javaType(order.child.dataType)} $primitiveB;
- {
- ${eval.code}
- $isNullB = ${eval.isNull};
- $primitiveB = ${eval.value};
- }
- if ($isNullA && $isNullB) {
- // Nothing
- } else if ($isNullA) {
- return ${
- order.nullOrdering match {
- case NullsFirst => "-1"
- case NullsLast => "1"
- }};
- } else if ($isNullB) {
- return ${
- order.nullOrdering match {
- case NullsFirst => "1"
- case NullsLast => "-1"
- }};
- } else {
- int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)};
- if (comp != 0) {
- return ${if (asc) "comp" else "-comp"};
- }
- }
- """
+ |${l.code}
+ |${r.code}
+ |if (${l.isNull} && ${r.isNull}) {
+ | // Nothing
+ |} else if (${l.isNull}) {
+ | return $lRetValue;
+ |} else if (${r.isNull}) {
+ | return $rRetValue;
+ |} else {
+ | int comp = ${ctx.genComp(dt, l.value, r.value)};
+ | if (comp != 0) {
+ | return ${if (asc) "comp" else "-comp"};
+ | }
+ |}
+ """.stripMargin
}
val code = ctx.splitExpressions(
@@ -133,30 +126,24 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
returnType = "int",
makeSplitFunction = { body =>
s"""
- InternalRow ${ctx.INPUT_ROW} = null; // Holds current row being evaluated.
- $body
- return 0;
- """
+ |$body
+ |return 0;
+ """.stripMargin
},
foldFunctions = { funCalls =>
funCalls.zipWithIndex.map { case (funCall, i) =>
val comp = ctx.freshName("comp")
s"""
- int $comp = $funCall;
- if ($comp != 0) {
- return $comp;
- }
- """
+ |int $comp = $funCall;
+ |if ($comp != 0) {
+ | return $comp;
+ |}
+ """.stripMargin
}.mkString
})
ctx.currentVars = oldCurrentVars
ctx.INPUT_ROW = oldInputRow
- // make sure INPUT_ROW is declared even if splitExpressions
- // returns an inlined block
- s"""
- |InternalRow $inputRow = null;
- |$code
- """.stripMargin
+ code
}
protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
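The rewrite above evaluates the order keys against rows `a` and `b` instead of re-binding `INPUT_ROW`; the generated comparator behaves the same from the caller's point of view, e.g. in a sketch like this (catalyst-internal API, shown for illustration only):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, SortOrder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.types.IntegerType

object GenerateOrderingSketch {
  def main(args: Array[String]): Unit = {
    // Compare rows by their first int column, ascending.
    val order = SortOrder(BoundReference(0, IntegerType, nullable = true), Ascending)
    val ordering = GenerateOrdering.generate(Seq(order))
    println(ordering.compare(InternalRow(1), InternalRow(2))) // expected to be negative
    println(ordering.compare(InternalRow(2), InternalRow(2))) // expected to be 0
  }
}
```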
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b24d7486f3454..43116743e9952 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -350,7 +350,7 @@ case class MapValues(child: Expression)
> SELECT _FUNC_(map(1, 'a', 2, 'b'));
[{"key":1,"value":"a"},{"key":2,"value":"b"}]
""",
- since = "2.4.0")
+ since = "3.0.0")
case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
@@ -521,13 +521,18 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression {
override def checkInputDataTypes(): TypeCheckResult = {
- var funcName = s"function $prettyName"
+ val funcName = s"function $prettyName"
if (children.exists(!_.dataType.isInstanceOf[MapType])) {
TypeCheckResult.TypeCheckFailure(
s"input to $funcName should all be of type map, but it's " +
children.map(_.dataType.catalogString).mkString("[", ", ", "]"))
} else {
- TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName)
+ val sameTypeCheck = TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName)
+ if (sameTypeCheck.isFailure) {
+ sameTypeCheck
+ } else {
+ TypeUtils.checkForMapKeyType(dataType.keyType)
+ }
}
}
@@ -740,7 +745,8 @@ case class MapFromEntries(child: Expression) extends UnaryExpression {
@transient override lazy val dataType: MapType = dataTypeDetails.get._1
override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
- case Some(_) => TypeCheckResult.TypeCheckSuccess
+ case Some((mapType, _, _)) =>
+ TypeUtils.checkForMapKeyType(mapType.keyType)
case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " +
s"${child.dataType.catalogString} type. $prettyName accepts only arrays of pair structs.")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 0361372b6b732..6b77996789f1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -161,11 +161,11 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
"The given values of function map should all be the same type, but they are " +
values.map(_.dataType.catalogString).mkString("[", ", ", "]"))
} else {
- TypeCheckResult.TypeCheckSuccess
+ TypeUtils.checkForMapKeyType(dataType.keyType)
}
}
- override def dataType: DataType = {
+ override def dataType: MapType = {
MapType(
keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType))
.getOrElse(StringType),
@@ -224,6 +224,16 @@ case class MapFromArrays(left: Expression, right: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ val keyType = left.dataType.asInstanceOf[ArrayType].elementType
+ TypeUtils.checkForMapKeyType(keyType)
+ }
+ }
+
override def dataType: DataType = {
MapType(
keyType = left.dataType.asInstanceOf[ArrayType].elementType,
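Assuming `TypeUtils.checkForMapKeyType` rejects map-typed keys (as its call sites here suggest), the observable effect is an analysis-time error instead of a runtime surprise; a hedged SQL-level sketch:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object MapKeyTypeCheckSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("map-key-check").getOrCreate()

    // Atomic key types keep working as before.
    spark.sql("SELECT map_from_arrays(array(1, 2), array('a', 'b'))").show(truncate = false)

    // Expected to be rejected at analysis time if map-typed keys are disallowed.
    try {
      spark.sql("SELECT map(map(1, 'a'), 2)").collect()
    } catch {
      case e: AnalysisException => println(s"rejected: ${e.getMessage}")
    }
    spark.stop()
  }
}
```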
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index aff372b899f86..1e4e1c663c90e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -106,6 +106,10 @@ case class CsvToStructs(
throw new AnalysisException(s"from_csv() doesn't support the ${mode.name} mode. " +
s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.")
}
+ ExprUtils.verifyColumnNameOfCorruptRecord(
+ nullableSchema,
+ parsedOptions.columnNameOfCorruptRecord)
+
val actualSchema =
StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index b07d9466ba0d1..8b31021866220 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -264,13 +264,13 @@ case class ArrayTransform(
* Filters entries in a map using the provided function.
*/
@ExpressionDescription(
-usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
-examples = """
+ usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
+ examples = """
Examples:
> SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v);
{1:0,3:-1}
""",
-since = "2.4.0")
+ since = "3.0.0")
case class MapFilter(
argument: Expression,
function: Expression)
@@ -504,7 +504,7 @@ case class ArrayAggregate(
> SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
{2:1,4:2,6:3}
""",
- since = "2.4.0")
+ since = "3.0.0")
case class TransformKeys(
argument: Expression,
function: Expression)
@@ -514,6 +514,10 @@ case class TransformKeys(
override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ TypeUtils.checkForMapKeyType(function.dataType)
+ }
+
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = {
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
}
@@ -554,7 +558,7 @@ case class TransformKeys(
> SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
{1:2,2:4,3:6}
""",
- since = "2.4.0")
+ since = "3.0.0")
case class TransformValues(
argument: Expression,
function: Expression)
@@ -605,7 +609,7 @@ case class TransformValues(
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2));
{1:"ax",2:"by"}
""",
- since = "2.4.0")
+ since = "3.0.0")
case class MapZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index eafcb6161036e..47304d835fdf8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -550,15 +550,23 @@ case class JsonToStructs(
s"Input schema ${nullableSchema.catalogString} must be a struct, an array or a map.")
}
- // This converts parsed rows to the desired output by the given schema.
@transient
- lazy val converter = nullableSchema match {
- case _: StructType =>
- (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null
- case _: ArrayType =>
- (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null
- case _: MapType =>
- (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null
+ private lazy val castRow = nullableSchema match {
+ case _: StructType => (row: InternalRow) => row
+ case _: ArrayType => (row: InternalRow) => row.getArray(0)
+ case _: MapType => (row: InternalRow) => row.getMap(0)
+ }
+
+ // This converts parsed rows to the desired output by the given schema.
+ private def convertRow(rows: Iterator[InternalRow]) = {
+ if (rows.hasNext) {
+ val result = rows.next()
+ // JSON's parser produces one record only.
+ assert(!rows.hasNext)
+ castRow(result)
+ } else {
+ throw new IllegalArgumentException("Expected one row from JSON parser.")
+ }
}
val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
@@ -569,14 +577,17 @@ case class JsonToStructs(
throw new IllegalArgumentException(s"from_json() doesn't support the ${mode.name} mode. " +
s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.")
}
- val rawParser = new JacksonParser(nullableSchema, parsedOptions, allowArrayAsStructs = false)
- val createParser = CreateJacksonParser.utf8String _
-
- val parserSchema = nullableSchema match {
- case s: StructType => s
- case other => StructType(StructField("value", other) :: Nil)
+ val (parserSchema, actualSchema) = nullableSchema match {
+ case s: StructType =>
+ ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord)
+ (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)))
+ case other =>
+ (StructType(StructField("value", other) :: Nil), other)
}
+ val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false)
+ val createParser = CreateJacksonParser.utf8String _
+
new FailureSafeParser[UTF8String](
input => rawParser.parse(input, createParser, identity[UTF8String]),
mode,
@@ -591,7 +602,7 @@ case class JsonToStructs(
copy(timeZoneId = Option(timeZoneId))
override def nullSafeEval(json: Any): Any = {
- converter(parser.parse(json.asInstanceOf[UTF8String]))
+ convertRow(parser.parse(json.asInstanceOf[UTF8String]))
}
override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
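With the corrupt-record column verified and filtered out of the parsing schema, `from_json` in PERMISSIVE mode can surface malformed input in that column; a hedged sketch (names and sample data are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object FromJsonCorruptRecordSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("from_json-sketch").getOrCreate()
    import spark.implicits._

    // The corrupt-record column must be a nullable string, otherwise analysis should fail.
    val schema = new StructType()
      .add("a", IntegerType)
      .add("_corrupt_record", StringType)

    val df = Seq("""{"a": 1}""", """{"a":""").toDF("value")
    df.select(from_json($"value", schema).as("parsed")).show(truncate = false)

    spark.stop()
  }
}
```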
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 584a2946bd564..02b48f9e30f2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -115,6 +115,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
def withQualifier(newQualifier: Seq[String]): Attribute
def withName(newName: String): Attribute
def withMetadata(newMetadata: Metadata): Attribute
+ def withExprId(newExprId: ExprId): Attribute
override def toAttribute: Attribute = this
def newInstance(): Attribute
@@ -129,6 +130,9 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
* Note that exprId and qualifiers are in a separate parameter list because
* we only pattern match on child and name.
*
+ * Note that when creating a new Alias, all the [[AttributeReference]]s that refer to
+ * the original alias should be updated to the new one.
+ *
* @param child The computation being performed
* @param name The name to be associated with the result of computing [[child]].
* @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
@@ -299,7 +303,7 @@ case class AttributeReference(
}
}
- def withExprId(newExprId: ExprId): AttributeReference = {
+ override def withExprId(newExprId: ExprId): AttributeReference = {
if (exprId == newExprId) {
this
} else {
@@ -362,6 +366,8 @@ case class PrettyAttribute(
throw new UnsupportedOperationException
override def qualifier: Seq[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
+ override def withExprId(newExprId: ExprId): Attribute =
+ throw new UnsupportedOperationException
override def nullable: Boolean = true
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 4fd36a47cef52..59c897b6a53ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -462,12 +462,12 @@ case class NewInstance(
val d = outerObj.getClass +: paramTypes
val c = getConstructor(outerObj.getClass +: paramTypes)
(args: Seq[AnyRef]) => {
- c.newInstance(outerObj +: args: _*)
+ c(outerObj +: args)
}
}.getOrElse {
val c = getConstructor(paramTypes)
(args: Seq[AnyRef]) => {
- c.newInstance(args: _*)
+ c(args)
}
}
}
@@ -486,10 +486,16 @@ case class NewInstance(
ev.isNull = resultIsNull
- val constructorCall = outer.map { gen =>
- s"${gen.value}.new ${cls.getSimpleName}($argString)"
- }.getOrElse {
- s"new $className($argString)"
+ val constructorCall = cls.getConstructors.size match {
+ // If there are no constructors, the `new` method will fail. In
+ // this case we can try to call the apply method constructor
+ // that might be defined on the companion object.
+ case 0 => s"$className$$.MODULE$$.apply($argString)"
+ case _ => outer.map { gen =>
+ s"${gen.value}.new ${cls.getSimpleName}($argString)"
+ }.getOrElse {
+ s"new $className($argString)"
+ }
}
val code = code"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 64152e04928d2..e10b8a327c01a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -76,16 +76,19 @@ private[sql] class JSONOptions(
// Whether to ignore column of all null values or empty array/struct during schema inference
val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false)
+ // A language tag in IETF BCP 47 format
+ val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US)
+
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
// Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
val dateFormat: FastDateFormat =
- FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), Locale.US)
+ FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale)
val timestampFormat: FastDateFormat =
FastDateFormat.getInstance(
- parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US)
+ parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale)
val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 57c7f2faf3107..773ff5a7a4013 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -399,7 +399,7 @@ class JacksonParser(
// a null first token is equivalent to testing for input.trim.isEmpty
// but it works on any token stream and not just strings
parser.nextToken() match {
- case null => Nil
+ case null => throw new RuntimeException("No JSON token found")
case _ => rootConverter.apply(parser) match {
case null => throw new RuntimeException("Root converter returned null")
case rows => rows
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a50a94ea2765f..30d6f6b5db470 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -84,7 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
SimplifyConditionals,
RemoveDispensableExpressions,
SimplifyBinaryComparison,
- ReplaceNullWithFalse,
+ ReplaceNullWithFalseInPredicate,
PruneFilters,
EliminateSorts,
SimplifyCasts,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
new file mode 100644
index 0000000000000..72a60f692ac78
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If}
+import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or}
+import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.BooleanType
+import org.apache.spark.util.Utils
+
+
+/**
+ * A rule that replaces `Literal(null, BooleanType)` with `FalseLiteral`, if possible, in the search
+ * condition of the WHERE/HAVING/ON(JOIN) clauses, which contain an implicit Boolean operator
+ * "(search condition) = TRUE". The replacement is only valid when `Literal(null, BooleanType)` is
+ * semantically equivalent to `FalseLiteral` when evaluating the whole search condition.
+ *
+ * Please note that FALSE and NULL are not interchangeable in general, for example when the search
+ * condition contains NOT or null-tolerant expressions. Thus, the rule is very conservative and
+ * applicable only in very limited cases.
+ *
+ * For example, `Filter(Literal(null, BooleanType))` is equal to `Filter(FalseLiteral)`.
+ *
+ * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`;
+ * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually
+ * `Filter(FalseLiteral)`.
+ *
+ * Moreover, this rule also transforms predicates in all [[If]] expressions as well as branch
+ * conditions in all [[CaseWhen]] expressions, even if they are not part of the search conditions.
+ *
+ * For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))` can be simplified
+ * into `Project(Literal(2))`.
+ */
+object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
+ case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond)))
+ case p: LogicalPlan => p transformExpressions {
+ case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
+ case cw @ CaseWhen(branches, _) =>
+ val newBranches = branches.map { case (cond, value) =>
+ replaceNullWithFalse(cond) -> value
+ }
+ cw.copy(branches = newBranches)
+ case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) =>
+ val newLambda = lf.copy(function = replaceNullWithFalse(func))
+ af.copy(function = newLambda)
+ case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) =>
+ val newLambda = lf.copy(function = replaceNullWithFalse(func))
+ ae.copy(function = newLambda)
+ case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) =>
+ val newLambda = lf.copy(function = replaceNullWithFalse(func))
+ mf.copy(function = newLambda)
+ }
+ }
+
+ /**
+ * Recursively traverse the Boolean-type expression to replace
+ * `Literal(null, BooleanType)` with `FalseLiteral`, if possible.
+ *
+ * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit
+ * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or
+ * `Literal(null, BooleanType)`.
+ */
+ private def replaceNullWithFalse(e: Expression): Expression = e match {
+ case Literal(null, BooleanType) =>
+ FalseLiteral
+ case And(left, right) =>
+ And(replaceNullWithFalse(left), replaceNullWithFalse(right))
+ case Or(left, right) =>
+ Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
+ case cw: CaseWhen if cw.dataType == BooleanType =>
+ val newBranches = cw.branches.map { case (cond, value) =>
+ replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
+ }
+ val newElseValue = cw.elseValue.map(replaceNullWithFalse)
+ CaseWhen(newBranches, newElseValue)
+ case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
+ If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
+ case e if e.dataType == BooleanType =>
+ e
+ case e =>
+ val message = "Expected a Boolean type expression in replaceNullWithFalse, " +
+ s"but got the type `${e.dataType.catalogString}` in `${e.sql}`."
+ if (Utils.isTesting) {
+ throw new IllegalArgumentException(message)
+ } else {
+ logWarning(message)
+ e
+ }
+ }
+}
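A minimal, hedged sketch of the rule in isolation, using catalyst's expression/plan DSL (internal helpers, assumed to be on the classpath):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation}
import org.apache.spark.sql.types.BooleanType

object ReplaceNullWithFalseSketch {
  def main(args: Array[String]): Unit = {
    val relation = LocalRelation('a.int)
    // WHERE NULL: the null boolean condition can never select a row.
    val original = Filter(Literal(null, BooleanType), relation)
    val rewritten = ReplaceNullWithFalseInPredicate(original)
    // Expected: the condition becomes `false`, which later rules such as PruneFilters can remove.
    println(rewritten.treeString)
  }
}
```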
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 2b29b49d00ab9..468a950fb1087 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -736,60 +736,3 @@ object CombineConcats extends Rule[LogicalPlan] {
flattenConcats(concat)
}
}
-
-/**
- * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations.
- *
- * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates
- * in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions.
- *
- * For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`.
- *
- * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`;
- * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually
- * `Filter(FalseLiteral)`.
- *
- * As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can
- * benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))`
- * can be simplified into `Project(Literal(2))`.
- *
- * As a result, many unnecessary computations can be removed in the query optimization phase.
- */
-object ReplaceNullWithFalse extends Rule[LogicalPlan] {
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
- case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond)))
- case p: LogicalPlan => p transformExpressions {
- case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
- case cw @ CaseWhen(branches, _) =>
- val newBranches = branches.map { case (cond, value) =>
- replaceNullWithFalse(cond) -> value
- }
- cw.copy(branches = newBranches)
- }
- }
-
- /**
- * Recursively replaces `Literal(null, _)` with `FalseLiteral`.
- *
- * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit
- * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`.
- */
- private def replaceNullWithFalse(e: Expression): Expression = e match {
- case cw: CaseWhen if cw.dataType == BooleanType =>
- val newBranches = cw.branches.map { case (cond, value) =>
- replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
- }
- val newElseValue = cw.elseValue.map(replaceNullWithFalse)
- CaseWhen(newBranches, newElseValue)
- case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
- If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
- case And(left, right) =>
- And(replaceNullWithFalse(left), replaceNullWithFalse(right))
- case Or(left, right) =>
- Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
- case Literal(null, _) => FalseLiteral
- case _ => e
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 7149edee0173e..6ebb194d71c2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
}
/**
- * PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF
- * and pull them out from join condition. For python udf accessing attributes from only one side,
- * they are pushed down by operation push down rules. If not (e.g. user disables filter push
- * down rules), we need to pull them out in this rule too.
+ * A PythonUDF in a join condition can't be evaluated if it refers to attributes from both join
+ * sides. See `ExtractPythonUDFs` for details. This rule detects such un-evaluable PythonUDFs and
+ * pulls them out of the join condition.
*/
object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
- def hasPythonUDF(expression: Expression): Boolean = {
- expression.collectFirst { case udf: PythonUDF => udf }.isDefined
+
+ private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = {
+ expr.find { e =>
+ PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right)
+ }.isDefined
}
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case j @ Join(_, _, joinType, condition)
- if condition.isDefined && hasPythonUDF(condition.get) =>
+ case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) =>
if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
// The current strategy only support InnerLike and LeftSemi join because for other type,
// it breaks SQL semantic if we run the join condition as a filter after join. If we pass
@@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH
}
// If condition expression contains python udf, it will be moved out from
// the new join conditions.
- val (udf, rest) =
- splitConjunctivePredicates(condition.get).partition(hasPythonUDF)
+ val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j))
val newCondition = if (rest.isEmpty) {
- logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," +
+ logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," +
s" it will be moved out and the join plan will be turned to cross join.")
None
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index f09c5ceefed13..a26ec4eed8648 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -17,15 +17,15 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.{AliasIdentifier}
+import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning}
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
import org.apache.spark.util.random.RandomSampler
/**
@@ -485,7 +485,7 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)])
override def output: Seq[Attribute] = child.output
override def simpleString: String = {
- val cteAliases = Utils.truncatedString(cteRelations.map(_._1), "[", ", ", "]")
+ val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]")
s"CTE $cteAliases"
}
@@ -575,6 +575,18 @@ case class Range(
}
}
+/**
+ * This is a Group by operator with the aggregate functions and projections.
+ *
+ * @param groupingExpressions expressions for grouping keys
+ * @param aggregateExpressions expressions for a project list, which could contain
+ * [[AggregateFunction]]s.
+ *
+ * Note: Currently, aggregateExpressions also serves as the project list of this GROUP BY
+ * operator. Until projection is separated from grouping and aggregation, expression-level
+ * optimization on aggregateExpressions should be avoided, because an entry may reference an
+ * expression in groupingExpressions. For example, see the rule
+ * [[org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps]].
+ */
case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index e991a2dc7462f..cf6ff4f986399 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.rules
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.sideBySide
@@ -66,6 +67,17 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
*/
protected def isPlanIntegral(plan: TreeType): Boolean = true
+ /**
+ * Executes the batches of rules defined by the subclass, and also tracks timing info for each
+ * rule using the provided tracker.
+ * @see [[execute]]
+ */
+ def executeAndTrack(plan: TreeType, tracker: QueryPlanningTracker): TreeType = {
+ QueryPlanningTracker.withTracker(tracker) {
+ execute(plan)
+ }
+ }
+
/**
* Executes the batches of rules defined by the subclass. The batches are executed serially
* using the defined execution strategy. Within each batch, rules are also executed serially.
@@ -74,6 +86,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
var curPlan = plan
val queryExecutionMetrics = RuleExecutor.queryExecutionMeter
val planChangeLogger = new PlanChangeLogger()
+ val tracker: Option[QueryPlanningTracker] = QueryPlanningTracker.get
batches.foreach { batch =>
val batchStartPlan = curPlan
@@ -88,8 +101,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
val startTime = System.nanoTime()
val result = rule(plan)
val runTime = System.nanoTime() - startTime
+ val effective = !result.fastEquals(plan)
- if (!result.fastEquals(plan)) {
+ if (effective) {
queryExecutionMetrics.incNumEffectiveExecution(rule.ruleName)
queryExecutionMetrics.incTimeEffectiveExecutionBy(rule.ruleName, runTime)
planChangeLogger.log(rule.ruleName, plan, result)
@@ -97,6 +111,9 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
queryExecutionMetrics.incExecutionTimeBy(rule.ruleName, runTime)
queryExecutionMetrics.incNumExecution(rule.ruleName)
+ // Record timing information using QueryPlanningTracker
+ tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective))
+
// Run the structural integrity checker against the plan after each rule.
if (!isPlanIntegral(result)) {
val message = s"After applying rule ${rule.ruleName} in batch ${batch.name}, " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 566b14d7d0e19..624680eb8ec0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql.catalyst.trees
+import java.io.Writer
import java.util.UUID
import scala.collection.Map
import scala.reflect.ClassTag
+import org.apache.commons.io.output.StringBuilderWriter
import org.apache.commons.lang3.ClassUtils
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
@@ -35,9 +37,9 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
private class MutableInt(var i: Int)
@@ -440,10 +442,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case tn: TreeNode[_] => tn.simpleString :: Nil
case seq: Seq[Any] if seq.toSet.subsetOf(allChildren.asInstanceOf[Set[Any]]) => Nil
case iter: Iterable[_] if iter.isEmpty => Nil
- case seq: Seq[_] => Utils.truncatedString(seq, "[", ", ", "]") :: Nil
- case set: Set[_] => Utils.truncatedString(set.toSeq, "{", ", ", "}") :: Nil
+ case seq: Seq[_] => truncatedString(seq, "[", ", ", "]") :: Nil
+ case set: Set[_] => truncatedString(set.toSeq, "{", ", ", "}") :: Nil
case array: Array[_] if array.isEmpty => Nil
- case array: Array[_] => Utils.truncatedString(array, "[", ", ", "]") :: Nil
+ case array: Array[_] => truncatedString(array, "[", ", ", "]") :: Nil
case null => Nil
case None => Nil
case Some(null) => Nil
@@ -471,7 +473,20 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
def treeString: String = treeString(verbose = true)
def treeString(verbose: Boolean, addSuffix: Boolean = false): String = {
- generateTreeString(0, Nil, new StringBuilder, verbose = verbose, addSuffix = addSuffix).toString
+ val writer = new StringBuilderWriter()
+ try {
+ treeString(writer, verbose, addSuffix)
+ writer.toString
+ } finally {
+ writer.close()
+ }
+ }
+
+ def treeString(
+ writer: Writer,
+ verbose: Boolean,
+ addSuffix: Boolean): Unit = {
+ generateTreeString(0, Nil, writer, verbose, "", addSuffix)
}
/**
@@ -523,7 +538,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
protected def innerChildren: Seq[TreeNode[_]] = Seq.empty
/**
- * Appends the string representation of this node and its children to the given StringBuilder.
+ * Appends the string representation of this node and its children to the given Writer.
*
* The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at
* depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and
@@ -534,16 +549,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
- builder: StringBuilder,
+ writer: Writer,
verbose: Boolean,
prefix: String = "",
- addSuffix: Boolean = false): StringBuilder = {
+ addSuffix: Boolean = false): Unit = {
if (depth > 0) {
lastChildren.init.foreach { isLast =>
- builder.append(if (isLast) " " else ": ")
+ writer.write(if (isLast) " " else ": ")
}
- builder.append(if (lastChildren.last) "+- " else ":- ")
+ writer.write(if (lastChildren.last) "+- " else ":- ")
}
val str = if (verbose) {
@@ -551,27 +566,25 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
} else {
simpleString
}
- builder.append(prefix)
- builder.append(str)
- builder.append("\n")
+ writer.write(prefix)
+ writer.write(str)
+ writer.write("\n")
if (innerChildren.nonEmpty) {
innerChildren.init.foreach(_.generateTreeString(
- depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose,
+ depth + 2, lastChildren :+ children.isEmpty :+ false, writer, verbose,
addSuffix = addSuffix))
innerChildren.last.generateTreeString(
- depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose,
+ depth + 2, lastChildren :+ children.isEmpty :+ true, writer, verbose,
addSuffix = addSuffix)
}
if (children.nonEmpty) {
children.init.foreach(_.generateTreeString(
- depth + 1, lastChildren :+ false, builder, verbose, prefix, addSuffix))
+ depth + 1, lastChildren :+ false, writer, verbose, prefix, addSuffix))
children.last.generateTreeString(
- depth + 1, lastChildren :+ true, builder, verbose, prefix, addSuffix)
+ depth + 1, lastChildren :+ true, writer, verbose, prefix, addSuffix)
}
-
- builder
}
/**
@@ -653,7 +666,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
t.forall(_.isInstanceOf[Partitioning]) || t.forall(_.isInstanceOf[DataType]) =>
JArray(t.map(parseToJson).toList)
case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] =>
- JString(Utils.truncatedString(t, "[", ", ", "]"))
+ JString(truncatedString(t, "[", ", ", "]"))
case t: Seq[_] => JNull
case m: Map[_, _] => JNull
// if it's a scala object, we can simply keep the full class path.
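A small sketch of the new Writer-based overload; here `plan` stands for any TreeNode, e.g. a resolved LogicalPlan:

    import java.io.PrintWriter

    // Stream the tree directly to a Writer instead of materializing a StringBuilder first.
    val out = new PrintWriter(System.out)
    try {
      plan.treeString(out, verbose = true, addSuffix = false)
    } finally {
      out.flush()
    }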
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
index 9bacd3b925be3..ea619c6a7666c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
@@ -199,7 +199,7 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
var shift = 0
while (idx < m && i < REGISTERS_PER_WORD) {
val Midx = (word >>> shift) & REGISTER_WORD_MASK
- zInverse += 1.0 / (1 << Midx)
+ zInverse += 1.0 / (1L << Midx)
if (Midx == 0) {
V += 1.0d
}
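A quick illustration of the overflow this one-character fix addresses (the value is for demonstration only):

    // With an Int literal the shift count is taken modulo 32, so a large register value
    // silently corrupts the harmonic-mean accumulator.
    val midx = 33
    val wrong = 1.0 / (1 << midx)   // 1 << 33 == 2 as an Int, so this yields 0.5
    val right = 1.0 / (1L << midx)  // 1L << 33 == 8589934592L, roughly 1.16e-10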
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
index 3190e511e2cb5..2a03f85ab594b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats
* Helper class to compute approximate quantile summary.
* This implementation is based on the algorithm proposed in the paper:
* "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael
- * and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670)
+ * and Khanna, Sanjeev. (https://doi.org/10.1145/375663.375670)
*
* In order to optimize for speed, it maintains an internal buffer of the last seen samples,
* and only inserts them after crossing a certain size threshold. This guarantees a near-constant
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 76218b459ef0d..2a71fdb7592bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -46,12 +46,20 @@ object TypeUtils {
if (TypeCoercion.haveSameType(types)) {
TypeCheckResult.TypeCheckSuccess
} else {
- return TypeCheckResult.TypeCheckFailure(
+ TypeCheckResult.TypeCheckFailure(
s"input to $caller should all be the same type, but it's " +
types.map(_.catalogString).mkString("[", ", ", "]"))
}
}
+ def checkForMapKeyType(keyType: DataType): TypeCheckResult = {
+ if (keyType.existsRecursively(_.isInstanceOf[MapType])) {
+ TypeCheckResult.TypeCheckFailure("The key of map cannot be/contain map.")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
def getNumeric(t: DataType): Numeric[Any] =
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
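A short sketch of the new helper; the nested-array case is included only to show that the check is recursive:

    import org.apache.spark.sql.catalyst.util.TypeUtils
    import org.apache.spark.sql.types._

    TypeUtils.checkForMapKeyType(IntegerType)                      // TypeCheckSuccess
    TypeUtils.checkForMapKeyType(MapType(IntegerType, StringType)) // TypeCheckFailure(...)
    TypeUtils.checkForMapKeyType(ArrayType(MapType(IntegerType, StringType))) // also fails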
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index 0978e92dd4f72..277584b20dcd2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -19,13 +19,16 @@ package org.apache.spark.sql.catalyst
import java.io._
import java.nio.charset.StandardCharsets
+import java.util.concurrent.atomic.AtomicBoolean
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{NumericType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
-package object util {
+package object util extends Logging {
/** Silences output to stderr or stdout for the duration of f */
def quietly[A](f: => A): A = {
@@ -167,6 +170,38 @@ package object util {
builder.toString()
}
+ /** Whether we have warned about plan string truncation yet. */
+ private val truncationWarningPrinted = new AtomicBoolean(false)
+
+ /**
+ * Format a sequence with semantics similar to calling .mkString(). Any elements beyond
+ * maxNumFields will be dropped and replaced by a "... N more fields" placeholder.
+ *
+ * @return the trimmed and formatted string.
+ */
+ def truncatedString[T](
+ seq: Seq[T],
+ start: String,
+ sep: String,
+ end: String,
+ maxNumFields: Int = SQLConf.get.maxToStringFields): String = {
+ if (seq.length > maxNumFields) {
+ if (truncationWarningPrinted.compareAndSet(false, true)) {
+ logWarning(
+ "Truncated the string representation of a plan since it was too large. This " +
+ s"behavior can be adjusted by setting '${SQLConf.MAX_TO_STRING_FIELDS.key}'.")
+ }
+ val numFields = math.max(0, maxNumFields - 1)
+ seq.take(numFields).mkString(
+ start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end)
+ } else {
+ seq.mkString(start, sep, end)
+ }
+ }
+
+ /** Shorthand for calling truncatedString() without start or end strings. */
+ def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "")
+
/* FIX ME
implicit class debugLogging(a: Any) {
def debugLogging() {
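A quick sketch of the truncation semantics; maxNumFields is passed explicitly here, while by default it comes from spark.sql.debug.maxToStringFields:

    import org.apache.spark.sql.catalyst.util.truncatedString

    // Four elements with a limit of three: the last two collapse into a placeholder.
    truncatedString(Seq(1, 2, 3, 4), "[", ", ", "]", maxNumFields = 3)  // "[1, 2, ... 2 more fields]"

    // The two-argument shorthand behaves like mkString with a separator only.
    truncatedString(Seq("a", "b"), ", ")  // "a, b"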
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index cf67ac1ff813c..be4496c87293e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1611,6 +1611,22 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE =
+ buildConf("spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue")
+ .internal()
+ .doc("When set to true, the key attribute resulted from running `Dataset.groupByKey` " +
+ "for non-struct key type, will be named as `value`, following the behavior of Spark " +
+ "version 2.4 and earlier.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val MAX_TO_STRING_FIELDS = buildConf("spark.sql.debug.maxToStringFields")
+ .doc("Maximum number of fields of sequence-like entries can be converted to strings " +
+ "in debug output. Any elements beyond the limit will be dropped and replaced by a" +
+ """ "... N more fields" placeholder.""")
+ .intConf
+ .createWithDefault(25)
+
val MAX_REPEATED_ALIAS_SIZE =
buildConf("spark.sql.maxRepeatedAliasSize")
.internal()
@@ -2047,6 +2063,11 @@ class SQLConf extends Serializable with Logging {
def integralDivideReturnLong: Boolean = getConf(SQLConf.LEGACY_INTEGRALDIVIDE_RETURN_LONG)
+ def nameNonStructGroupingKeyAsValue: Boolean =
+ getConf(SQLConf.NAME_NON_STRUCT_GROUPING_KEY_AS_VALUE)
+
+ def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS)
+
def maxRepeatedAliasSize: Int = getConf(SQLConf.MAX_REPEATED_ALIAS_SIZE)
/** ********************** SQLConf functionality methods ************ */
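A hedged usage sketch for the two new settings, assuming an active SparkSession named `spark`:

    spark.conf.set("spark.sql.debug.maxToStringFields", 100L)
    spark.conf.set("spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue", true)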
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index c43cc748655e8..5367ce2af8e9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.types
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.expressions.Expression
/**
@@ -134,7 +134,7 @@ object AtomicType {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
abstract class NumericType extends AtomicType {
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 58c75b5dc7a35..7465569868f07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -21,7 +21,7 @@ import scala.math.Ordering
import org.json4s.JsonDSL._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.util.ArrayData
/**
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
object ArrayType extends AbstractDataType {
/**
* Construct a [[ArrayType]] object with the given element type. The `containsNull` is true.
@@ -60,7 +60,7 @@ object ArrayType extends AbstractDataType {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
/** No-arg constructor for kryo. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
index 032d6b54aeb79..cc8b3e6e399a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
@@ -20,15 +20,14 @@ package org.apache.spark.sql.types
import scala.math.Ordering
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.util.TypeUtils
-
/**
* The data type representing `Array[Byte]` values.
* Please use the singleton `DataTypes.BinaryType`.
*/
-@InterfaceStability.Stable
+@Stable
class BinaryType private() extends AtomicType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "BinaryType$" in byte code.
@@ -55,5 +54,5 @@ class BinaryType private() extends AtomicType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object BinaryType extends BinaryType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala
index 63f354d2243cf..5e3de71caa37e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala
@@ -20,15 +20,14 @@ package org.apache.spark.sql.types
import scala.math.Ordering
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
-
+import org.apache.spark.annotation.Stable
/**
* The data type representing `Boolean` values. Please use the singleton `DataTypes.BooleanType`.
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class BooleanType private() extends AtomicType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "BooleanType$" in byte code.
@@ -48,5 +47,5 @@ class BooleanType private() extends AtomicType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object BooleanType extends BooleanType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala
index 5854c3f5ba116..9d400eefc0f8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.types
import scala.math.{Integral, Numeric, Ordering}
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
/**
* The data type representing `Byte` values. Please use the singleton `DataTypes.ByteType`.
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class ByteType private() extends IntegralType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "ByteType$" in byte code.
@@ -52,5 +52,5 @@ class ByteType private() extends IntegralType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object ByteType extends ByteType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
index 2342036a57460..8e297874a0d62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.types
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
/**
* The data type representing calendar time intervals. The calendar time interval is stored
@@ -29,7 +29,7 @@ import org.apache.spark.annotation.InterfaceStability
*
* @since 1.5.0
*/
-@InterfaceStability.Stable
+@Stable
class CalendarIntervalType private() extends DataType {
override def defaultSize: Int = 16
@@ -40,5 +40,5 @@ class CalendarIntervalType private() extends DataType {
/**
* @since 1.5.0
*/
-@InterfaceStability.Stable
+@Stable
case object CalendarIntervalType extends CalendarIntervalType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index e53628d11ccf3..c58f7a2397374 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -26,7 +26,7 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -38,7 +38,7 @@ import org.apache.spark.util.Utils
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
abstract class DataType extends AbstractDataType {
/**
* Enables matching against DataType for expressions:
@@ -111,7 +111,7 @@ abstract class DataType extends AbstractDataType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
object DataType {
private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r
@@ -180,7 +180,7 @@ object DataType {
("pyClass", _),
("sqlType", _),
("type", JString("udt"))) =>
- Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
+ Utils.classForName(udtClass).getConstructor().newInstance().asInstanceOf[UserDefinedType[_]]
// Python UDT
case JSortedObject(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala
index 9e70dd486a125..7491014b22dab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import scala.math.Ordering
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
/**
* A date type, supporting "0001-01-01" through "9999-12-31".
@@ -31,7 +31,7 @@ import org.apache.spark.annotation.InterfaceStability
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class DateType private() extends AtomicType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "DateType$" in byte code.
@@ -53,5 +53,5 @@ class DateType private() extends AtomicType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object DateType extends DateType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 9eed2eb202045..0192059a3a39f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import java.lang.{Long => JLong}
import java.math.{BigInteger, MathContext, RoundingMode}
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.AnalysisException
/**
@@ -31,7 +31,7 @@ import org.apache.spark.sql.AnalysisException
* - If decimalVal is set, it represents the whole decimal value
* - Otherwise, the decimal value is longVal / (10 ** _scale)
*/
-@InterfaceStability.Unstable
+@Unstable
final class Decimal extends Ordered[Decimal] with Serializable {
import org.apache.spark.sql.types.Decimal._
@@ -185,9 +185,21 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}
- def toScalaBigInt: BigInt = BigInt(toLong)
+ def toScalaBigInt: BigInt = {
+ if (decimalVal.ne(null)) {
+ decimalVal.toBigInt()
+ } else {
+ BigInt(toLong)
+ }
+ }
- def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong)
+ def toJavaBigInteger: java.math.BigInteger = {
+ if (decimalVal.ne(null)) {
+ decimalVal.underlying().toBigInteger()
+ } else {
+ java.math.BigInteger.valueOf(toLong)
+ }
+ }
def toUnscaledLong: Long = {
if (decimalVal.ne(null)) {
@@ -407,7 +419,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}
-@InterfaceStability.Unstable
+@Unstable
object Decimal {
val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP
val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN
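A hedged illustration of the behavior change; the value is arbitrary, chosen only to exceed Long range:

    import org.apache.spark.sql.types.Decimal

    // 10^30 cannot be represented as a Long; the converters now use the underlying
    // BigDecimal instead of going through toLong.
    val big = Decimal(BigDecimal("1" + "0" * 30))
    val asScala = big.toScalaBigInt      // exact value, BigInt(10).pow(30)
    val asJava = big.toJavaBigInteger    // likewise exact, as a java.math.BigInteger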
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 15004e4b9667d..25eddaf06a780 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -21,11 +21,10 @@ import java.util.Locale
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
-
/**
* The data type representing `java.math.BigDecimal` values.
* A Decimal that must have fixed precision (the maximum number of digits) and scale (the number
@@ -39,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class DecimalType(precision: Int, scale: Int) extends FractionalType {
if (scale > precision) {
@@ -110,7 +109,7 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
object DecimalType extends AbstractDataType {
import scala.math.min
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
index a5c79ff01ca06..afd3353397019 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
@@ -21,7 +21,7 @@ import scala.math.{Fractional, Numeric, Ordering}
import scala.math.Numeric.DoubleAsIfIntegral
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.util.Utils
/**
@@ -29,7 +29,7 @@ import org.apache.spark.util.Utils
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class DoubleType private() extends FractionalType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "DoubleType$" in byte code.
@@ -54,5 +54,5 @@ class DoubleType private() extends FractionalType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object DoubleType extends DoubleType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
index 352147ec936c9..6d98987304081 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
@@ -21,7 +21,7 @@ import scala.math.{Fractional, Numeric, Ordering}
import scala.math.Numeric.FloatAsIfIntegral
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.util.Utils
/**
@@ -29,7 +29,7 @@ import org.apache.spark.util.Utils
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class FloatType private() extends FractionalType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "FloatType$" in byte code.
@@ -55,5 +55,5 @@ class FloatType private() extends FractionalType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object FloatType extends FloatType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
index a85e3729188d9..0755202d20df1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.types
import scala.math.{Integral, Numeric, Ordering}
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
/**
* The data type representing `Int` values. Please use the singleton `DataTypes.IntegerType`.
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class IntegerType private() extends IntegralType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "IntegerType$" in byte code.
@@ -51,5 +51,5 @@ class IntegerType private() extends IntegralType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object IntegerType extends IntegerType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala
index 0997028fc1057..3c49c721fdc88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.types
import scala.math.{Integral, Numeric, Ordering}
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
/**
* The data type representing `Long` values. Please use the singleton `DataTypes.LongType`.
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class LongType private() extends IntegralType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "LongType$" in byte code.
@@ -51,5 +51,5 @@ class LongType private() extends IntegralType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object LongType extends LongType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 594e155268bf6..29b9ffc0c3549 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
/**
* The data type for Maps. Keys in a map are not allowed to have `null` values.
@@ -31,7 +31,7 @@ import org.apache.spark.annotation.InterfaceStability
* @param valueType The data type of map values.
* @param valueContainsNull Indicates if map values have `null` values.
*/
-@InterfaceStability.Stable
+@Stable
case class MapType(
keyType: DataType,
valueType: DataType,
@@ -78,7 +78,7 @@ case class MapType(
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
object MapType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
index 7c15dc0de4b6b..4979aced145c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable
import org.json4s._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
/**
@@ -37,7 +37,7 @@ import org.apache.spark.annotation.InterfaceStability
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
sealed class Metadata private[types] (private[types] val map: Map[String, Any])
extends Serializable {
@@ -117,7 +117,7 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any])
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
object Metadata {
private[this] val _empty = new Metadata(Map.empty)
@@ -228,7 +228,7 @@ object Metadata {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class MetadataBuilder {
private val map: mutable.Map[String, Any] = mutable.Map.empty
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala
index 494225b47a270..14097a5280d50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala
@@ -17,15 +17,14 @@
package org.apache.spark.sql.types
-import org.apache.spark.annotation.InterfaceStability
-
+import org.apache.spark.annotation.Stable
/**
* The data type representing `NULL` values. Please use the singleton `DataTypes.NullType`.
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class NullType private() extends DataType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "NullType$" in byte code.
@@ -38,5 +37,5 @@ class NullType private() extends DataType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object NullType extends NullType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
index 203e85e1c99bd..6756b209f432e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql.types
import scala.language.existentials
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
-@InterfaceStability.Evolving
+@Evolving
object ObjectType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType =
throw new UnsupportedOperationException(
@@ -38,7 +38,7 @@ object ObjectType extends AbstractDataType {
/**
* Represents a JVM object that is passing through Spark SQL expression evaluation.
*/
-@InterfaceStability.Evolving
+@Evolving
case class ObjectType(cls: Class[_]) extends DataType {
override def defaultSize: Int = 4096
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala
index ee655c338b59f..9b5ddfef1ccf5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.types
import scala.math.{Integral, Numeric, Ordering}
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
/**
* The data type representing `Short` values. Please use the singleton `DataTypes.ShortType`.
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class ShortType private() extends IntegralType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "ShortType$" in byte code.
@@ -51,5 +51,5 @@ class ShortType private() extends IntegralType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object ShortType extends ShortType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
index 59b124cda7d14..8ce1cd078e312 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import scala.math.Ordering
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -28,7 +28,7 @@ import org.apache.spark.unsafe.types.UTF8String
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class StringType private() extends AtomicType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "StringType$" in byte code.
@@ -48,6 +48,6 @@ class StringType private() extends AtomicType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object StringType extends StringType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala
index 35f9970a0aaec..6f6b561d67d49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier}
/**
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdenti
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class StructField(
name: String,
dataType: DataType,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 06289b1483203..6e8bbde7787a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -24,10 +24,10 @@ import scala.util.control.NonFatal
import org.json4s.JsonDSL._
import org.apache.spark.SparkException
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
-import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier}
+import org.apache.spark.sql.catalyst.util.{quoteIdentifier, truncatedString}
import org.apache.spark.util.Utils
/**
@@ -95,7 +95,7 @@ import org.apache.spark.util.Utils
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
/** No-arg constructor for kryo. */
@@ -346,7 +346,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
override def simpleString: String = {
val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}")
- Utils.truncatedString(fieldTypes, "struct<", ",", ">")
+ truncatedString(fieldTypes, "struct<", ",", ">")
}
override def catalogString: String = {
@@ -422,7 +422,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
object StructType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = new StructType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala
index fdb91e0499920..a20f155418f8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import scala.math.Ordering
import scala.reflect.runtime.universe.typeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
/**
* The data type representing `java.sql.Timestamp` values.
@@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class TimestampType private() extends AtomicType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "TimestampType$" in byte code.
@@ -50,5 +50,5 @@ class TimestampType private() extends AtomicType {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case object TimestampType extends TimestampType
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala
new file mode 100644
index 0000000000000..120b284a77854
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/QueryPlanningTrackerSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import org.apache.spark.SparkFunSuite
+
+class QueryPlanningTrackerSuite extends SparkFunSuite {
+
+ test("phases") {
+ val t = new QueryPlanningTracker
+ t.measureTime("p1") {
+ Thread.sleep(1)
+ }
+
+ assert(t.phases("p1") > 0)
+ assert(!t.phases.contains("p2"))
+
+ val old = t.phases("p1")
+
+ t.measureTime("p1") {
+ Thread.sleep(1)
+ }
+ assert(t.phases("p1") > old)
+ }
+
+ test("rules") {
+ val t = new QueryPlanningTracker
+ t.recordRuleInvocation("r1", 1, effective = false)
+ t.recordRuleInvocation("r2", 2, effective = true)
+ t.recordRuleInvocation("r3", 1, effective = false)
+ t.recordRuleInvocation("r3", 2, effective = true)
+
+ val rules = t.rules
+
+ assert(rules("r1").totalTimeNs == 1)
+ assert(rules("r1").numInvocations == 1)
+ assert(rules("r1").numEffectiveInvocations == 0)
+
+ assert(rules("r2").totalTimeNs == 2)
+ assert(rules("r2").numInvocations == 1)
+ assert(rules("r2").numEffectiveInvocations == 1)
+
+ assert(rules("r3").totalTimeNs == 3)
+ assert(rules("r3").numInvocations == 2)
+ assert(rules("r3").numEffectiveInvocations == 1)
+ }
+
+ test("topRulesByTime") {
+ val t = new QueryPlanningTracker
+
+ // Return empty seq when no rules have been recorded yet
+ assert(t.topRulesByTime(0) == Seq.empty)
+ assert(t.topRulesByTime(1) == Seq.empty)
+
+ t.recordRuleInvocation("r2", 2, effective = true)
+ t.recordRuleInvocation("r4", 4, effective = true)
+ t.recordRuleInvocation("r1", 1, effective = false)
+ t.recordRuleInvocation("r3", 3, effective = false)
+
+ // k <= total size
+ assert(t.topRulesByTime(0) == Seq.empty)
+ val top = t.topRulesByTime(2)
+ assert(top.size == 2)
+ assert(top(0)._1 == "r4")
+ assert(top(1)._1 == "r3")
+
+ // k > total size
+ assert(t.topRulesByTime(10).size == 4)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index d98589db323cc..80824cc2a7f21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -109,6 +109,30 @@ object TestingUDT {
}
}
+/** An example derived from Twitter/Scrooge codegen for Thrift. */
+object ScroogeLikeExample {
+ def apply(x: Int): ScroogeLikeExample = new Immutable(x)
+
+ def unapply(_item: ScroogeLikeExample): Option[Int] = Some(_item.x)
+
+ class Immutable(val x: Int) extends ScroogeLikeExample
+}
+
+trait ScroogeLikeExample extends Product1[Int] with Serializable {
+ import ScroogeLikeExample._
+
+ def x: Int
+
+ def _1: Int = x
+
+ override def canEqual(other: Any): Boolean = other.isInstanceOf[ScroogeLikeExample]
+
+ override def equals(other: Any): Boolean =
+ canEqual(other) &&
+ this.x == other.asInstanceOf[ScroogeLikeExample].x
+
+ override def hashCode: Int = x
+}
class ScalaReflectionSuite extends SparkFunSuite {
import org.apache.spark.sql.catalyst.ScalaReflection._
@@ -362,4 +386,11 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
}
+
+ test("SPARK-8288: schemaFor works for a class with only a companion object constructor") {
+ val schema = schemaFor[ScroogeLikeExample]
+ assert(schema === Schema(
+ StructType(Seq(
+ StructField("x", IntegerType, nullable = false))), nullable = true))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 94778840d706b..117e96175e92a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import scala.beans.{BeanInfo, BeanProperty}
-
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -30,8 +28,9 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
-@BeanInfo
-private[sql] case class GroupableData(@BeanProperty data: Int)
+private[sql] case class GroupableData(data: Int) {
+ def getData: Int = data
+}
private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
@@ -50,8 +49,9 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
private[spark] override def asNullable: GroupableUDT = this
}
-@BeanInfo
-private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int])
+private[sql] case class UngroupableData(data: Map[Int, Int]) {
+ def getData: Map[Int, Int] = data
+}
private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 3d7c91870133b..fab1b776a3c72 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -21,6 +21,7 @@ import java.net.URI
import java.util.Locale
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
@@ -54,7 +55,7 @@ trait AnalysisTest extends PlanTest {
expectedPlan: LogicalPlan,
caseSensitive: Boolean = true): Unit = {
val analyzer = getAnalyzer(caseSensitive)
- val actualPlan = analyzer.executeAndCheck(inputPlan)
+ val actualPlan = analyzer.executeAndCheck(inputPlan, new QueryPlanningTracker)
comparePlans(actualPlan, expectedPlan)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
index 8da4d7e3aa372..aa5eda8e5ba87 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import java.util.TimeZone
+import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -109,7 +110,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
Seq(UnresolvedAlias(Multiply(unresolved_a, Literal(2))),
unresolved_b, UnresolvedAlias(count(unresolved_c))))
- val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2)
+ val resultPlan = getAnalyzer(true).executeAndCheck(originalPlan2, new QueryPlanningTracker)
val gExpressions = resultPlan.asInstanceOf[Aggregate].groupingExpressions
assert(gExpressions.size == 3)
val firstGroupingExprAttrName =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala
index fe57c199b8744..64bd07534b19b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
@@ -34,6 +35,7 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest {
private lazy val uuid3 = Uuid().as('_uuid3)
private lazy val uuid1Ref = uuid1.toAttribute
+ private val tracker = new QueryPlanningTracker
private val analyzer = getAnalyzer(caseSensitive = true)
private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = {
@@ -47,7 +49,7 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest {
test("analyzed plan sets random seed for Uuid expression") {
val plan = r.select(a, uuid1)
- val resolvedPlan = analyzer.executeAndCheck(plan)
+ val resolvedPlan = analyzer.executeAndCheck(plan, tracker)
getUuidExpressions(resolvedPlan).foreach { u =>
assert(u.resolved)
assert(u.randomSeed.isDefined)
@@ -56,14 +58,14 @@ class ResolvedUuidExpressionsSuite extends AnalysisTest {
test("Uuid expressions should have different random seeds") {
val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
- val resolvedPlan = analyzer.executeAndCheck(plan)
+ val resolvedPlan = analyzer.executeAndCheck(plan, tracker)
assert(getUuidExpressions(resolvedPlan).map(_.randomSeed.get).distinct.length == 3)
}
test("Different analyzed plans should have different random seeds in Uuids") {
val plan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
- val resolvedPlan1 = analyzer.executeAndCheck(plan)
- val resolvedPlan2 = analyzer.executeAndCheck(plan)
+ val resolvedPlan1 = analyzer.executeAndCheck(plan, tracker)
+ val resolvedPlan2 = analyzer.executeAndCheck(plan, tracker)
val uuids1 = getUuidExpressions(resolvedPlan1)
val uuids2 = getUuidExpressions(resolvedPlan2)
assert(uuids1.distinct.length == 3)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index e9b100b3b30db..be8fd90c4c52a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -128,13 +128,13 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
encodeDecodeTest(-3.7f, "primitive float")
encodeDecodeTest(-3.7, "primitive double")
- encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean")
- encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte")
- encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short")
- encodeDecodeTest(new java.lang.Integer(-3), "boxed int")
- encodeDecodeTest(new java.lang.Long(-3L), "boxed long")
- encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float")
- encodeDecodeTest(new java.lang.Double(-3.7), "boxed double")
+ encodeDecodeTest(java.lang.Boolean.FALSE, "boxed boolean")
+ encodeDecodeTest(java.lang.Byte.valueOf(-3: Byte), "boxed byte")
+ encodeDecodeTest(java.lang.Short.valueOf(-3: Short), "boxed short")
+ encodeDecodeTest(java.lang.Integer.valueOf(-3), "boxed int")
+ encodeDecodeTest(java.lang.Long.valueOf(-3L), "boxed long")
+ encodeDecodeTest(java.lang.Float.valueOf(-3.7f), "boxed float")
+ encodeDecodeTest(java.lang.Double.valueOf(-3.7), "boxed double")
encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
@@ -224,7 +224,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
productTest(
RepeatedData(
Seq(1, 2),
- Seq(new Integer(1), null, new Integer(2)),
+ Seq(Integer.valueOf(1), null, Integer.valueOf(2)),
Map(1 -> 2L),
Map(1 -> null),
PrimitiveData(1, 1, 1, 1, 1, 1, true)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 2e0adbb465008..d2edb2f24688d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -25,6 +25,7 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -108,32 +109,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
}
test("Map Concat") {
- val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType,
+ val m0 = Literal.create(create_map("a" -> "1", "b" -> "2"), MapType(StringType, StringType,
valueContainsNull = false))
- val m1 = Literal.create(Map("c" -> "3", "a" -> "4"), MapType(StringType, StringType,
+ val m1 = Literal.create(create_map("c" -> "3", "a" -> "4"), MapType(StringType, StringType,
valueContainsNull = false))
- val m2 = Literal.create(Map("d" -> "4", "e" -> "5"), MapType(StringType, StringType))
- val m3 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType))
- val m4 = Literal.create(Map("a" -> null, "c" -> "3"), MapType(StringType, StringType))
- val m5 = Literal.create(Map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType))
- val m6 = Literal.create(Map("a" -> null, "c" -> 3), MapType(StringType, IntegerType))
- val m7 = Literal.create(Map(List(1, 2) -> 1, List(3, 4) -> 2),
+ val m2 = Literal.create(create_map("d" -> "4", "e" -> "5"), MapType(StringType, StringType))
+ val m3 = Literal.create(create_map("a" -> "1", "b" -> "2"), MapType(StringType, StringType))
+ val m4 = Literal.create(create_map("a" -> null, "c" -> "3"), MapType(StringType, StringType))
+ val m5 = Literal.create(create_map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType))
+ val m6 = Literal.create(create_map("a" -> null, "c" -> 3), MapType(StringType, IntegerType))
+ val m7 = Literal.create(create_map(List(1, 2) -> 1, List(3, 4) -> 2),
MapType(ArrayType(IntegerType), IntegerType))
- val m8 = Literal.create(Map(List(5, 6) -> 3, List(1, 2) -> 4),
+ val m8 = Literal.create(create_map(List(5, 6) -> 3, List(1, 2) -> 4),
MapType(ArrayType(IntegerType), IntegerType))
- val m9 = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2),
- MapType(MapType(IntegerType, IntegerType), IntegerType))
- val m10 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4),
- MapType(MapType(IntegerType, IntegerType), IntegerType))
- val m11 = Literal.create(Map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType,
+ val m9 = Literal.create(create_map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType,
valueContainsNull = false))
- val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType,
+ val m10 = Literal.create(create_map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType,
valueContainsNull = false))
- val m13 = Literal.create(Map(1 -> 2, 3 -> 4),
+ val m11 = Literal.create(create_map(1 -> 2, 3 -> 4),
MapType(IntegerType, IntegerType, valueContainsNull = false))
- val m14 = Literal.create(Map(5 -> 6),
+ val m12 = Literal.create(create_map(5 -> 6),
MapType(IntegerType, IntegerType, valueContainsNull = false))
- val m15 = Literal.create(Map(7 -> null),
+ val m13 = Literal.create(create_map(7 -> null),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val mNull = Literal.create(null, MapType(StringType, StringType))
@@ -147,7 +144,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
// maps with no overlap
checkEvaluation(MapConcat(Seq(m0, m2)),
- Map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5"))
+ create_map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5"))
// 3 maps
checkEvaluation(MapConcat(Seq(m0, m1, m2)),
@@ -174,7 +171,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
)
// keys that are primitive
- checkEvaluation(MapConcat(Seq(m11, m12)),
+ checkEvaluation(MapConcat(Seq(m9, m10)),
(
Array(1, 2, 3, 4), // keys
Array("1", "2", "3", "4") // values
@@ -189,20 +186,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
)
)
- // keys that are maps, with overlap
- checkEvaluation(MapConcat(Seq(m9, m10)),
- (
- Array(Map(1 -> 2, 3 -> 4), Map(5 -> 6, 7 -> 8), Map(9 -> 10, 11 -> 12),
- Map(1 -> 2, 3 -> 4)), // keys
- Array(1, 2, 3, 4) // values
- )
- )
-
// both keys and value are primitive and valueContainsNull = false
- checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6))
+ checkEvaluation(MapConcat(Seq(m11, m12)), create_map(1 -> 2, 3 -> 4, 5 -> 6))
// both keys and value are primitive and valueContainsNull = true
- checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null))
+ checkEvaluation(MapConcat(Seq(m11, m13)), create_map(1 -> 2, 3 -> 4, 7 -> null))
// null map
checkEvaluation(MapConcat(Seq(m0, mNull)), null)
@@ -211,7 +199,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapConcat(Seq(mNull)), null)
// single map
- checkEvaluation(MapConcat(Seq(m0)), Map("a" -> "1", "b" -> "2"))
+ checkEvaluation(MapConcat(Seq(m0)), create_map("a" -> "1", "b" -> "2"))
// no map
checkEvaluation(MapConcat(Seq.empty), Map.empty)
@@ -245,12 +233,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
assert(MapConcat(Seq(m1, mNull)).nullable)
val mapConcat = MapConcat(Seq(
- Literal.create(Map(Seq(1, 2) -> Seq("a", "b")),
+ Literal.create(create_map(Seq(1, 2) -> Seq("a", "b")),
MapType(
ArrayType(IntegerType, containsNull = false),
ArrayType(StringType, containsNull = false),
valueContainsNull = false)),
- Literal.create(Map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null),
+ Literal.create(create_map(Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null),
MapType(
ArrayType(IntegerType, containsNull = true),
ArrayType(StringType, containsNull = true),
@@ -264,6 +252,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq(1, 2) -> Seq("a", "b"),
Seq(3, 4, null) -> Seq("c", "d", null),
Seq(6) -> null))
+
+ // map key can't be map
+ val mapOfMap = Literal.create(Map(Map(1 -> 2, 3 -> 4) -> 1, Map(5 -> 6, 7 -> 8) -> 2),
+ MapType(MapType(IntegerType, IntegerType), IntegerType))
+ val mapOfMap2 = Literal.create(Map(Map(9 -> 10, 11 -> 12) -> 3, Map(1 -> 2, 3 -> 4) -> 4),
+ MapType(MapType(IntegerType, IntegerType), IntegerType))
+ val map = MapConcat(Seq(mapOfMap, mapOfMap2))
+ map.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key")
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("The key of map cannot be/contain map"))
+ }
}
test("MapFromEntries") {
@@ -274,20 +274,20 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
StructField("b", valueType))),
true)
}
- def r(values: Any*): InternalRow = create_row(values: _*)
+ def row(values: Any*): InternalRow = create_row(values: _*)
// Primitive-type keys and values
val aiType = arrayType(IntegerType, IntegerType)
- val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType)
- val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType)
+ val ai0 = Literal.create(Seq(row(1, 10), row(2, 20), row(3, 20)), aiType)
+ val ai1 = Literal.create(Seq(row(1, null), row(2, 20), row(3, null)), aiType)
val ai2 = Literal.create(Seq.empty, aiType)
val ai3 = Literal.create(null, aiType)
- val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType)
- val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType)
- val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType)
+ val ai4 = Literal.create(Seq(row(1, 10), row(1, 20)), aiType)
+ val ai5 = Literal.create(Seq(row(1, 10), row(null, 20)), aiType)
+ val ai6 = Literal.create(Seq(null, row(2, 20), null), aiType)
- checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20))
- checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null))
+ checkEvaluation(MapFromEntries(ai0), create_map(1 -> 10, 2 -> 20, 3 -> 20))
+ checkEvaluation(MapFromEntries(ai1), create_map(1 -> null, 2 -> 20, 3 -> null))
checkEvaluation(MapFromEntries(ai2), Map.empty)
checkEvaluation(MapFromEntries(ai3), null)
checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1))
@@ -298,23 +298,36 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
// Non-primitive-type keys and values
val asType = arrayType(StringType, StringType)
- val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType)
- val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType)
+ val as0 = Literal.create(Seq(row("a", "aa"), row("b", "bb"), row("c", "bb")), asType)
+ val as1 = Literal.create(Seq(row("a", null), row("b", "bb"), row("c", null)), asType)
val as2 = Literal.create(Seq.empty, asType)
val as3 = Literal.create(null, asType)
- val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType)
- val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType)
- val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType)
+ val as4 = Literal.create(Seq(row("a", "aa"), row("a", "bb")), asType)
+ val as5 = Literal.create(Seq(row("a", "aa"), row(null, "bb")), asType)
+ val as6 = Literal.create(Seq(null, row("b", "bb"), null), asType)
- checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb"))
- checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null))
+ checkEvaluation(MapFromEntries(as0), create_map("a" -> "aa", "b" -> "bb", "c" -> "bb"))
+ checkEvaluation(MapFromEntries(as1), create_map("a" -> null, "b" -> "bb", "c" -> null))
checkEvaluation(MapFromEntries(as2), Map.empty)
checkEvaluation(MapFromEntries(as3), null)
checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a"))
+ checkEvaluation(MapFromEntries(as6), null)
+
+ // Map key can't be null
checkExceptionInExpression[RuntimeException](
MapFromEntries(as5),
"The first field from a struct (key) can't be null.")
- checkEvaluation(MapFromEntries(as6), null)
+
+ // map key can't be map
+ val structOfMap = row(create_map(1 -> 1), 1)
+ val map = MapFromEntries(Literal.create(
+ Seq(structOfMap),
+ arrayType(keyType = MapType(IntegerType, IntegerType), valueType = IntegerType)))
+ map.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key")
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("The key of map cannot be/contain map"))
+ }
}
test("Sort Array") {
@@ -1645,6 +1658,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull === true)
}
+ test("Array Except - null handling") {
+ val empty = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
+ val oneNull = Literal.create(Seq(null), ArrayType(IntegerType))
+ val twoNulls = Literal.create(Seq(null, null), ArrayType(IntegerType))
+
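+ // ArrayExcept de-duplicates its result, so repeated nulls in the input collapse to at most one null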
+ checkEvaluation(ArrayExcept(oneNull, oneNull), Seq.empty)
+ checkEvaluation(ArrayExcept(twoNulls, twoNulls), Seq.empty)
+ checkEvaluation(ArrayExcept(twoNulls, oneNull), Seq.empty)
+ checkEvaluation(ArrayExcept(empty, oneNull), Seq.empty)
+ checkEvaluation(ArrayExcept(oneNull, empty), Seq(null))
+ checkEvaluation(ArrayExcept(twoNulls, empty), Seq(null))
+ }
+
test("Array Intersect") {
val a00 = Literal.create(Seq(1, 2, 4), ArrayType(IntegerType, false))
val a01 = Literal.create(Seq(4, 2), ArrayType(IntegerType, false))
@@ -1756,4 +1782,17 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull === true)
}
+
+ test("Array Intersect - null handling") {
+ val empty = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
+ val oneNull = Literal.create(Seq(null), ArrayType(IntegerType))
+ val twoNulls = Literal.create(Seq(null, null), ArrayType(IntegerType))
+
+ checkEvaluation(ArrayIntersect(oneNull, oneNull), Seq(null))
+ checkEvaluation(ArrayIntersect(twoNulls, twoNulls), Seq(null))
+ checkEvaluation(ArrayIntersect(twoNulls, oneNull), Seq(null))
+ checkEvaluation(ArrayIntersect(oneNull, twoNulls), Seq(null))
+ checkEvaluation(ArrayIntersect(empty, oneNull), Seq.empty)
+ checkEvaluation(ArrayIntersect(oneNull, empty), Seq.empty)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 77aaf55480ec2..d95f42e04e37c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._
@@ -158,40 +158,32 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
keys.zip(values).flatMap { case (k, v) => Seq(k, v) }
}
- def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
- // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order.
- scala.collection.immutable.ListMap(keys.zip(values): _*)
- }
-
val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
val strSeq = intSeq.map(_.toString)
+
checkEvaluation(CreateMap(Nil), Map.empty)
checkEvaluation(
CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))),
- createMap(intSeq, longSeq))
+ create_map(intSeq, longSeq))
checkEvaluation(
CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))),
- createMap(strSeq, longSeq))
+ create_map(strSeq, longSeq))
checkEvaluation(
CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))),
- createMap(longSeq, strSeq))
+ create_map(longSeq, strSeq))
val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)
checkEvaluation(
CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)),
- createMap(intSeq, strWithNull.map(_.value)))
- intercept[RuntimeException] {
- checkEvaluationWithoutCodegen(
- CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
- null, null)
- }
- intercept[RuntimeException] {
- checkEvaluationWithUnsafeProjection(
- CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
- null, null)
- }
+ create_map(intSeq, strWithNull.map(_.value)))
+ // Map key can't be null
+ checkExceptionInExpression[RuntimeException](
+ CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
+ "Cannot use null as map key")
+
+ // ArrayType map key and value
val map = CreateMap(Seq(
Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)),
Literal.create(strSeq, ArrayType(StringType, containsNull = false)),
@@ -202,15 +194,21 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
ArrayType(IntegerType, containsNull = true),
ArrayType(StringType, containsNull = true),
valueContainsNull = false))
- checkEvaluation(map, createMap(Seq(intSeq, intSeq :+ null), Seq(strSeq, strSeq :+ null)))
+ checkEvaluation(map, create_map(intSeq -> strSeq, (intSeq :+ null) -> (strSeq :+ null)))
+
+ // map key can't be map
+ val map2 = CreateMap(Seq(
+ Literal.create(create_map(1 -> 1), MapType(IntegerType, IntegerType)),
+ Literal(1)
+ ))
+ map2.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key")
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("The key of map cannot be/contain map"))
+ }
}
test("MapFromArrays") {
- def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
- // catalyst map is order-sensitive, so we create ListMap here to preserve the elements order.
- scala.collection.immutable.ListMap(keys.zip(values): _*)
- }
-
val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
val strSeq = intSeq.map(_.toString)
@@ -228,24 +226,33 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
val nullArray = Literal.create(null, ArrayType(StringType, false))
- checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq))
- checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq))
- checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq))
+ checkEvaluation(MapFromArrays(intArray, longArray), create_map(intSeq, longSeq))
+ checkEvaluation(MapFromArrays(intArray, strArray), create_map(intSeq, strSeq))
+ checkEvaluation(MapFromArrays(integerArray, strArray), create_map(integerSeq, strSeq))
checkEvaluation(
- MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq))
+ MapFromArrays(strArray, intWithNullArray), create_map(strSeq, intWithNullSeq))
checkEvaluation(
- MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq))
+ MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq))
checkEvaluation(
- MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq))
+ MapFromArrays(strArray, longWithNullArray), create_map(strSeq, longWithNullSeq))
checkEvaluation(MapFromArrays(nullArray, nullArray), null)
- intercept[RuntimeException] {
- checkEvaluation(MapFromArrays(intWithNullArray, strArray), null)
- }
- intercept[RuntimeException] {
- checkEvaluation(
- MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
+ // Map key can't be null
+ checkExceptionInExpression[RuntimeException](
+ MapFromArrays(intWithNullArray, strArray),
+ "Cannot use null as map key")
+
+ // map key can't be map
+ val arrayOfMap = Seq(create_map(1 -> "a", 2 -> "b"))
+ val map = MapFromArrays(
+ Literal.create(arrayOfMap, ArrayType(MapType(IntegerType, StringType))),
+ Literal.create(Seq(1), ArrayType(IntegerType))
+ )
+ map.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key")
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("The key of map cannot be/contain map"))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
index d006197bd5678..98c93a4946f4f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql.catalyst.expressions
-import java.util.Calendar
+import java.text.SimpleDateFormat
+import java.util.{Calendar, Locale}
import org.scalatest.exceptions.TestFailedException
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.util._
@@ -209,4 +211,30 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P
"2015-12-31T16:00:00"
)
}
+
+ test("parse date with locale") {
+ Seq("en-US", "ru-RU").foreach { langTag =>
+ val locale = Locale.forLanguageTag(langTag)
+ val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05")
+ val schema = new StructType().add("d", DateType)
+ val dateFormat = "MMM yyyy"
+ val sdf = new SimpleDateFormat(dateFormat, locale)
+ val dateStr = sdf.format(date)
+ val options = Map("dateFormat" -> dateFormat, "locale" -> langTag)
+
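+ // "MMM yyyy" carries no day-of-month, so the parsed date is 2018-11-01, i.e. 17836 days since epoch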
+ checkEvaluation(
+ CsvToStructs(schema, options, Literal.create(dateStr), gmtId),
+ InternalRow(17836)) // number of days from 1970-01-01
+ }
+ }
+
+ test("verify corrupt column") {
+ checkExceptionInExpression[AnalysisException](
+ CsvToStructs(
+ schema = StructType.fromDDL("i int, _unparsed boolean"),
+ options = Map("columnNameOfCorruptRecord" -> "_unparsed"),
+ child = Literal.create("a"),
+ timeZoneId = gmtId),
+ expectedErrMsg = "The field for corrupt records must be string type and nullable")
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index da18475276a13..eb33325d0b31a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -48,6 +48,25 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst))
}
+ // Currently MapData just stores the key and value arrays. Its equality is not well implemented,
+ // as the order of the map entries should not matter for equality. This method creates MapData
+ // with the entries ordering preserved, so that we can deterministically test expressions with
+ // map input/output.
+ protected def create_map(entries: (_, _)*): ArrayBasedMapData = {
+ create_map(entries.map(_._1), entries.map(_._2))
+ }
+
+ protected def create_map(keys: Seq[_], values: Seq[_]): ArrayBasedMapData = {
+ assert(keys.length == values.length)
+ val keyArray = CatalystTypeConverters
+ .convertToCatalyst(keys)
+ .asInstanceOf[ArrayData]
+ val valueArray = CatalystTypeConverters
+ .convertToCatalyst(values)
+ .asInstanceOf[ArrayData]
+ new ArrayBasedMapData(keyArray, valueArray)
+ }
+
private def prepareEvaluation(expression: Expression): Expression = {
val serializer = new JavaSerializer(new SparkConf()).newInstance
val resolver = ResolveTimeZone(new SQLConf)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index e13f4d98295be..66bf18af95799 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types._
@@ -310,13 +311,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
test("TransformKeys") {
val ai0 = Literal.create(
- Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4),
+ create_map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val ai1 = Literal.create(
Map.empty[Int, Int],
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai2 = Literal.create(
- Map(1 -> 1, 2 -> null, 3 -> 3),
+ create_map(1 -> 1, 2 -> null, 3 -> 3),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))
@@ -324,26 +325,27 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
val plusValue: (Expression, Expression) => Expression = (k, v) => k + v
val modKey: (Expression, Expression) => Expression = (k, v) => k % 3
- checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4))
- checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4))
+ checkEvaluation(transformKeys(ai0, plusOne), create_map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4))
+ checkEvaluation(transformKeys(ai0, plusValue), create_map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4))
checkEvaluation(
- transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4))
+ transformKeys(transformKeys(ai0, plusOne), plusValue),
+ create_map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4))
checkEvaluation(transformKeys(ai0, modKey),
ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4)))
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(
transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int])
- checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3))
+ checkEvaluation(transformKeys(ai2, plusOne), create_map(2 -> 1, 3 -> null, 4 -> 3))
checkEvaluation(
- transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3))
+ transformKeys(transformKeys(ai2, plusOne), plusOne), create_map(3 -> 1, 4 -> null, 5 -> 3))
checkEvaluation(transformKeys(ai3, plusOne), null)
val as0 = Literal.create(
- Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
+ create_map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
MapType(StringType, StringType, valueContainsNull = false))
val as1 = Literal.create(
- Map("a" -> "xy", "bb" -> "yz", "ccc" -> null),
+ create_map("a" -> "xy", "bb" -> "yz", "ccc" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val as2 = Literal.create(null,
MapType(StringType, StringType, valueContainsNull = false))
@@ -355,26 +357,35 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
(k, v) => Length(k) + 1
checkEvaluation(
- transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx"))
+ transformKeys(as0, concatValue), create_map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx"))
checkEvaluation(
transformKeys(transformKeys(as0, concatValue), concatValue),
- Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx"))
+ create_map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx"))
checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String])
checkEvaluation(
transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength),
Map.empty[Int, String])
checkEvaluation(transformKeys(as0, convertKeyToKeyLength),
- Map(2 -> "xy", 3 -> "yz", 4 -> "zx"))
+ create_map(2 -> "xy", 3 -> "yz", 4 -> "zx"))
checkEvaluation(transformKeys(as1, convertKeyToKeyLength),
- Map(2 -> "xy", 3 -> "yz", 4 -> null))
+ create_map(2 -> "xy", 3 -> "yz", 4 -> null))
checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null)
checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String])
val ax0 = Literal.create(
- Map(1 -> "x", 2 -> "y", 3 -> "z"),
+ create_map(1 -> "x", 2 -> "y", 3 -> "z"),
MapType(IntegerType, StringType, valueContainsNull = false))
- checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z"))
+ checkEvaluation(transformKeys(ax0, plusOne), create_map(2 -> "x", 3 -> "y", 4 -> "z"))
+
+ // map key can't be map
+ val makeMap: (Expression, Expression) => Expression = (k, v) => CreateMap(Seq(k, v))
+ val map = transformKeys(ai0, makeMap)
+ map.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key")
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("The key of map cannot be/contain map"))
+ }
}
test("TransformValues") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 304642161146b..9b89a27c23770 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql.catalyst.expressions
-import java.util.Calendar
+import java.text.SimpleDateFormat
+import java.util.{Calendar, Locale}
import org.scalatest.exceptions.TestFailedException
import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.PlanTestBase
@@ -546,7 +548,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId),
- null
+ InternalRow(null)
)
}
@@ -737,4 +739,30 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
CreateMap(Seq(Literal.create("allowNumericLeadingZeros"), Literal.create("true")))),
"struct")
}
+
+ test("parse date with locale") {
+ Seq("en-US", "ru-RU").foreach { langTag =>
+ val locale = Locale.forLanguageTag(langTag)
+ val date = new SimpleDateFormat("yyyy-MM-dd").parse("2018-11-05")
+ val schema = new StructType().add("d", DateType)
+ val dateFormat = "MMM yyyy"
+ val sdf = new SimpleDateFormat(dateFormat, locale)
+ val dateStr = s"""{"d":"${sdf.format(date)}"}"""
+ val options = Map("dateFormat" -> dateFormat, "locale" -> langTag)
+
+ checkEvaluation(
+ JsonToStructs(schema, options, Literal.create(dateStr), gmtId),
+ InternalRow(17836)) // number of days from 1970-01-01
+ }
+ }
+
+ test("verify corrupt column") {
+ checkExceptionInExpression[AnalysisException](
+ JsonToStructs(
+ schema = StructType.fromDDL("i int, _unparsed boolean"),
+ options = Map("columnNameOfCorruptRecord" -> "_unparsed"),
+ child = Literal.create("""{"i":"a"}"""),
+ timeZoneId = gmtId),
+ expectedErrMsg = "The field for corrupt records must be string type and nullable")
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index d145fd0aaba47..436675bf50353 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.catalyst.ScroogeLikeExample
import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer, UnresolvedDeserializer}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders._
@@ -307,7 +308,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val conf = new SparkConf()
Seq(true, false).foreach { useKryo =>
val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf)
- val expected = serializer.newInstance().serialize(new Integer(1)).array()
+ val expected = serializer.newInstance().serialize(Integer.valueOf(1)).array()
val encodeUsingSerializer = EncodeUsingSerializer(inputObject, useKryo)
checkEvaluation(encodeUsingSerializer, expected, InternalRow.fromSeq(Seq(1)))
checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null)))
@@ -384,9 +385,9 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val conf = new SparkConf()
Seq(true, false).foreach { useKryo =>
val serializer = if (useKryo) new KryoSerializer(conf) else new JavaSerializer(conf)
- val input = serializer.newInstance().serialize(new Integer(1)).array()
+ val input = serializer.newInstance().serialize(Integer.valueOf(1)).array()
val decodeUsingSerializer = DecodeUsingSerializer(inputObject, ClassTag(cls), useKryo)
- checkEvaluation(decodeUsingSerializer, new Integer(1), InternalRow.fromSeq(Seq(input)))
+ checkEvaluation(decodeUsingSerializer, Integer.valueOf(1), InternalRow.fromSeq(Seq(input)))
checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null)))
}
}
@@ -410,6 +411,15 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
dataType = ObjectType(classOf[outerObj.Inner]),
outerPointer = Some(() => outerObj))
checkObjectExprEvaluation(newInst2, new outerObj.Inner(1))
+
+ // SPARK-8288: A class with only a companion object constructor
+ val newInst3 = NewInstance(
+ cls = classOf[ScroogeLikeExample],
+ arguments = Literal(1) :: Nil,
+ propagateNull = false,
+ dataType = ObjectType(classOf[ScroogeLikeExample]),
+ outerPointer = Some(() => outerObj))
+ checkObjectExprEvaluation(newInst3, ScroogeLikeExample(1))
}
test("LambdaVariable should support interpreted execution") {
@@ -575,7 +585,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// NULL key test
val scalaMapHasNullKey = scala.collection.Map[java.lang.Integer, String](
- null.asInstanceOf[java.lang.Integer] -> "v0", new java.lang.Integer(1) -> "v1")
+ null.asInstanceOf[java.lang.Integer] -> "v0", java.lang.Integer.valueOf(1) -> "v1")
val javaMapHasNullKey = new java.util.HashMap[java.lang.Integer, java.lang.String]() {
{
put(null, "v0")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
new file mode 100644
index 0000000000000..8e9c9972071ad
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AggregateExpressionSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Add, AttributeSet}
+
+class AggregateExpressionSuite extends SparkFunSuite {
+
+ test("test references from unresolved aggregate functions") {
+ val x = UnresolvedAttribute("x")
+ val y = UnresolvedAttribute("y")
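+ // the references of an unresolved aggregate should be collected from its child expression, not dropped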
+ val actual = AggregateExpression(Sum(Add(x, y)), mode = Complete, isDistinct = false).references
+ val expected = AttributeSet(x :: y :: Nil)
+ assert(expected == actual, s"Expected: $expected. Actual: $actual")
+ }
+
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index 294fce8e9a10f..63c7b42978025 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -41,9 +41,9 @@ class PercentileSuite extends SparkFunSuite {
val buffer = new OpenHashMap[AnyRef, Long]()
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
- // Check non-empty buffer serializa and deserialize.
+ // Check non-empty buffer serialize and deserialize.
data.foreach { key =>
- buffer.changeValue(new Integer(key), 1L, _ + 1L)
+ buffer.changeValue(Integer.valueOf(key), 1L, _ + 1L)
}
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala
new file mode 100644
index 0000000000000..3f1c91df7f2e9
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.PythonUDF
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf._
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
+class PullOutPythonUDFInJoinConditionSuite extends PlanTest {
+
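+ // Pulling a PythonUDF out of the join condition can leave a join with no condition at all,
+ // so CheckCartesianProducts runs afterwards to exercise both crossJoin settings.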
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Extract PythonUDF From JoinCondition", Once,
+ PullOutPythonUDFInJoinCondition) ::
+ Batch("Check Cartesian Products", Once,
+ CheckCartesianProducts) :: Nil
+ }
+
+ val attrA = 'a.int
+ val attrB = 'b.int
+ val attrC = 'c.int
+ val attrD = 'd.int
+
+ val testRelationLeft = LocalRelation(attrA, attrB)
+ val testRelationRight = LocalRelation(attrC, attrD)
+
+ // This join condition refers to attributes from 2 tables, but the PythonUDF inside it only
+ // refers to attributes from one side.
+ val evaluableJoinCond = {
+ val pythonUDF = PythonUDF("evaluable", null,
+ IntegerType,
+ Seq(attrA),
+ PythonEvalType.SQL_BATCHED_UDF,
+ udfDeterministic = true)
+ pythonUDF === attrC
+ }
+
+ // This join condition is a PythonUDF which refers to attributes from 2 tables.
+ val unevaluableJoinCond = PythonUDF("unevaluable", null,
+ BooleanType,
+ Seq(attrA, attrC),
+ PythonEvalType.SQL_BATCHED_UDF,
+ udfDeterministic = true)
+
+ val unsupportedJoinTypes = Seq(LeftOuter, RightOuter, FullOuter, LeftAnti)
+
+ private def comparePlanWithCrossJoinEnable(query: LogicalPlan, expected: LogicalPlan): Unit = {
+ // AnalysisException is thrown by CheckCartesianProducts when spark.sql.crossJoin.enabled=false
+ val exception = intercept[AnalysisException] {
+ Optimize.execute(query.analyze)
+ }
+ assert(exception.message.startsWith("Detected implicit cartesian product"))
+
+ // pull out the PythonUDF when spark.sql.crossJoin.enabled=true
+ withSQLConf(CROSS_JOINS_ENABLED.key -> "true") {
+ val optimized = Optimize.execute(query.analyze)
+ comparePlans(optimized, expected)
+ }
+ }
+
+ test("inner join condition with python udf") {
+ val query1 = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = Some(unevaluableJoinCond))
+ val expected1 = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = None).where(unevaluableJoinCond).analyze
+ comparePlanWithCrossJoinEnable(query1, expected1)
+
+ // evaluable PythonUDF will not be touched
+ val query2 = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = Some(evaluableJoinCond))
+ comparePlans(Optimize.execute(query2), query2)
+ }
+
+ test("left semi join condition with python udf") {
+ val query1 = testRelationLeft.join(
+ testRelationRight,
+ joinType = LeftSemi,
+ condition = Some(unevaluableJoinCond))
+ val expected1 = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = None).where(unevaluableJoinCond).select('a, 'b).analyze
+ comparePlanWithCrossJoinEnable(query1, expected1)
+
+ // evaluable PythonUDF will not be touched
+ val query2 = testRelationLeft.join(
+ testRelationRight,
+ joinType = LeftSemi,
+ condition = Some(evaluableJoinCond))
+ comparePlans(Optimize.execute(query2), query2)
+ }
+
+ test("unevaluable python udf and common condition") {
+ val query = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr))
+ val expected = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze
+ val optimized = Optimize.execute(query.analyze)
+ comparePlans(optimized, expected)
+ }
+
+ test("unevaluable python udf or common condition") {
+ val query = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr))
+ val expected = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze
+ comparePlanWithCrossJoinEnable(query, expected)
+ }
+
+ test("pull out whole complex condition with multiple unevaluable python udf") {
+ val pythonUDF1 = PythonUDF("pythonUDF1", null,
+ BooleanType,
+ Seq(attrA, attrC),
+ PythonEvalType.SQL_BATCHED_UDF,
+ udfDeterministic = true)
+ val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1
+
+ val query = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = Some(condition))
+ val expected = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = None).where(condition).analyze
+ comparePlanWithCrossJoinEnable(query, expected)
+ }
+
+ test("partial pull out complex condition with multiple unevaluable python udf") {
+ val pythonUDF1 = PythonUDF("pythonUDF1", null,
+ BooleanType,
+ Seq(attrA, attrC),
+ PythonEvalType.SQL_BATCHED_UDF,
+ udfDeterministic = true)
+ val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr
+
+ val query = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = Some(condition))
+ val expected = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze
+ val optimized = Optimize.execute(query.analyze)
+ comparePlans(optimized, expected)
+ }
+
+ test("pull out unevaluable python udf when it's mixed with evaluable one") {
+ val query = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = Some(evaluableJoinCond && unevaluableJoinCond))
+ val expected = testRelationLeft.join(
+ testRelationRight,
+ joinType = Inner,
+ condition = Some(evaluableJoinCond)).where(unevaluableJoinCond).analyze
+ val optimized = Optimize.execute(query.analyze)
+ comparePlans(optimized, expected)
+ }
+
+ test("throw an exception for not support join type") {
+ for (joinType <- unsupportedJoinTypes) {
+ val e = intercept[AnalysisException] {
+ val query = testRelationLeft.join(
+ testRelationRight,
+ joinType,
+ condition = Some(unevaluableJoinCond))
+ Optimize.execute(query.analyze)
+ }
+ assert(e.message.contentEquals(
+ s"Using PythonUDF in join condition of join type $joinType is not supported."))
+
+ val query2 = testRelationLeft.join(
+ testRelationRight,
+ joinType,
+ condition = Some(evaluableJoinCond))
+ comparePlans(Optimize.execute(query2), query2)
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
similarity index 84%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
index c6b5d0ec96776..ee0d04da3e46c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or}
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{BooleanType, IntegerType}
-class ReplaceNullWithFalseSuite extends PlanTest {
+class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
@@ -36,15 +36,23 @@ class ReplaceNullWithFalseSuite extends PlanTest {
ConstantFolding,
BooleanSimplification,
SimplifyConditionals,
- ReplaceNullWithFalse) :: Nil
+ ReplaceNullWithFalseInPredicate) :: Nil
}
- private val testRelation = LocalRelation('i.int, 'b.boolean)
+ private val testRelation =
+ LocalRelation('i.int, 'b.boolean, 'a.array(IntegerType), 'm.map(IntegerType, IntegerType))
private val anotherTestRelation = LocalRelation('d.int)
test("replace null inside filter and join conditions") {
- testFilter(originalCond = Literal(null), expectedCond = FalseLiteral)
- testJoin(originalCond = Literal(null), expectedCond = FalseLiteral)
+ testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
+ testJoin(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral)
+ }
+
+ test("Not expected type - replaceNullWithFalse") {
+ val e = intercept[IllegalArgumentException] {
+ testFilter(originalCond = Literal(null, IntegerType), expectedCond = FalseLiteral)
+ }.getMessage
+ assert(e.contains("but got the type `int` in `CAST(NULL AS INT)"))
}
test("replace null in branches of If") {
@@ -298,6 +306,26 @@ class ReplaceNullWithFalseSuite extends PlanTest {
testProjection(originalExpr = column, expectedExpr = column)
}
+ test("replace nulls in lambda function of ArrayFilter") {
+ testHigherOrderFunc('a, ArrayFilter, Seq('e))
+ }
+
+ test("replace nulls in lambda function of ArrayExists") {
+ testHigherOrderFunc('a, ArrayExists, Seq('e))
+ }
+
+ test("replace nulls in lambda function of MapFilter") {
+ testHigherOrderFunc('m, MapFilter, Seq('k, 'v))
+ }
+
+ test("inability to replace nulls in arbitrary higher-order function") {
+ val lambdaFunc = LambdaFunction(
+ function = If('e > 0, Literal(null, BooleanType), TrueLiteral),
+ arguments = Seq[NamedExpression]('e))
+ val column = ArrayTransform('a, lambdaFunc)
+ testProjection(originalExpr = column, expectedExpr = column)
+ }
+
private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = {
test((rel, exp) => rel.where(exp), originalCond, expectedCond)
}
@@ -310,6 +338,25 @@ class ReplaceNullWithFalseSuite extends PlanTest {
test((rel, exp) => rel.select(exp), originalExpr, expectedExpr)
}
+ private def testHigherOrderFunc(
+ argument: Expression,
+ createExpr: (Expression, Expression) => Expression,
+ lambdaArgs: Seq[NamedExpression]): Unit = {
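+ // the last lambda argument is the array element for ArrayFilter/ArrayExists and the map value for MapFilter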
+ val condArg = lambdaArgs.last
+ // the lambda body is: if(arg > 0, null, true)
+ val cond = GreaterThan(condArg, Literal(0))
+ val lambda1 = LambdaFunction(
+ function = If(cond, Literal(null, BooleanType), TrueLiteral),
+ arguments = lambdaArgs)
+ // the optimized lambda body is: if(arg > 0, false, true)
+ val lambda2 = LambdaFunction(
+ function = If(cond, FalseLiteral, TrueLiteral),
+ arguments = lambdaArgs)
+ testProjection(
+ originalExpr = createExpr(argument, lambda1) as 'x,
+ expectedExpr = createExpr(argument, lambda2) as 'x)
+ }
+
private def test(
func: (LogicalPlan, Expression) => LogicalPlan,
originalExpr: Expression,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index da3923f8d6477..17e00c9a3ead2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, GreaterThan, GreaterThanOrEqual, If, Literal, ReplicateRows}
+import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, GreaterThanOrEqual, If, Literal, Rand, ReplicateRows}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -196,4 +196,31 @@ class SetOperationSuite extends PlanTest {
))
comparePlans(expectedPlan, rewrittenPlan)
}
+
+ test("SPARK-23356 union: expressions with literal in project list are pushed down") {
+ val unionQuery = testUnion.select(('a + 1).as("aa"))
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ val unionCorrectAnswer =
+ Union(testRelation.select(('a + 1).as("aa")) ::
+ testRelation2.select(('d + 1).as("aa")) ::
+ testRelation3.select(('g + 1).as("aa")) :: Nil).analyze
+ comparePlans(unionOptimized, unionCorrectAnswer)
+ }
+
+ test("SPARK-23356 union: expressions in project list are pushed down") {
+ val unionQuery = testUnion.select(('a + 'b).as("ab"))
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ val unionCorrectAnswer =
+ Union(testRelation.select(('a + 'b).as("ab")) ::
+ testRelation2.select(('d + 'e).as("ab")) ::
+ testRelation3.select(('g + 'h).as("ab")) :: Nil).analyze
+ comparePlans(unionOptimized, unionCorrectAnswer)
+ }
+
+ test("SPARK-23356 union: no pushdown for non-deterministic expression") {
+ val unionQuery = testUnion.select('a, Rand(10).as("rnd"))
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ val unionCorrectAnswer = unionQuery.analyze
+ comparePlans(unionOptimized, unionCorrectAnswer)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 10de90c6a44ca..8abd7625c21aa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -228,4 +228,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
val decimal = Decimal.apply(bigInt)
assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808")
}
+
+ test("SPARK-26038: toScalaBigInt/toJavaBigInteger") {
+ // not fitting long
+ val decimal = Decimal("1234568790123456789012348790.1234879012345678901234568790")
+ assert(decimal.toScalaBigInt == scala.math.BigInt("1234568790123456789012348790"))
+ assert(decimal.toJavaBigInteger == new java.math.BigInteger("1234568790123456789012348790"))
+ // fitting long
+ val decimalLong = Decimal(123456789123456789L, 18, 9)
+ assert(decimalLong.toScalaBigInt == scala.math.BigInt("123456789"))
+ assert(decimalLong.toJavaBigInteger == new java.math.BigInteger("123456789"))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala
new file mode 100644
index 0000000000000..9c162026942f6
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/UtilSuite.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.truncatedString
+
+class UtilSuite extends SparkFunSuite {
+ test("truncatedString") {
+ assert(truncatedString(Nil, "[", ", ", "]", 2) == "[]")
+ assert(truncatedString(Seq(1, 2), "[", ", ", "]", 2) == "[1, 2]")
+ assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2) == "[1, ... 2 more fields]")
+ assert(truncatedString(Seq(1, 2, 3), "[", ", ", "]", -5) == "[, ... 3 more fields]")
+ assert(truncatedString(Seq(1, 2, 3), ", ") == "1, 2, 3")
+ }
+}
diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt
index 2d3bae442cc50..b07e8b1197ff0 100644
--- a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt
+++ b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt
@@ -2,268 +2,268 @@
SQL Single Numeric Column Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
SQL Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 21508 / 22112 0.7 1367.5 1.0X
-SQL Json 8705 / 8825 1.8 553.4 2.5X
-SQL Parquet Vectorized 157 / 186 100.0 10.0 136.7X
-SQL Parquet MR 1789 / 1794 8.8 113.8 12.0X
-SQL ORC Vectorized 156 / 166 100.9 9.9 138.0X
-SQL ORC Vectorized with copy 218 / 225 72.1 13.9 98.6X
-SQL ORC MR 1448 / 1492 10.9 92.0 14.9X
-
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+SQL CSV 26366 / 26562 0.6 1676.3 1.0X
+SQL Json 8709 / 8724 1.8 553.7 3.0X
+SQL Parquet Vectorized 166 / 187 94.8 10.5 159.0X
+SQL Parquet MR 1706 / 1720 9.2 108.4 15.5X
+SQL ORC Vectorized 167 / 174 94.2 10.6 157.9X
+SQL ORC Vectorized with copy 226 / 231 69.6 14.4 116.7X
+SQL ORC MR 1433 / 1465 11.0 91.1 18.4X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Parquet Reader Single TINYINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-ParquetReader Vectorized 202 / 211 77.7 12.9 1.0X
-ParquetReader Vectorized -> Row 118 / 120 133.5 7.5 1.7X
+ParquetReader Vectorized 200 / 207 78.7 12.7 1.0X
+ParquetReader Vectorized -> Row 117 / 119 134.7 7.4 1.7X
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
SQL Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 23282 / 23312 0.7 1480.2 1.0X
-SQL Json 9187 / 9189 1.7 584.1 2.5X
-SQL Parquet Vectorized 204 / 218 77.0 13.0 114.0X
-SQL Parquet MR 1941 / 1953 8.1 123.4 12.0X
-SQL ORC Vectorized 217 / 225 72.6 13.8 107.5X
-SQL ORC Vectorized with copy 279 / 289 56.3 17.8 83.4X
-SQL ORC MR 1541 / 1549 10.2 98.0 15.1X
-
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+SQL CSV 26489 / 26547 0.6 1684.1 1.0X
+SQL Json 8990 / 8998 1.7 571.5 2.9X
+SQL Parquet Vectorized 209 / 221 75.1 13.3 126.5X
+SQL Parquet MR 1949 / 1949 8.1 123.9 13.6X
+SQL ORC Vectorized 221 / 228 71.3 14.0 120.1X
+SQL ORC Vectorized with copy 315 / 319 49.9 20.1 84.0X
+SQL ORC MR 1527 / 1549 10.3 97.1 17.3X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Parquet Reader Single SMALLINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-ParquetReader Vectorized 288 / 297 54.6 18.3 1.0X
-ParquetReader Vectorized -> Row 255 / 257 61.7 16.2 1.1X
+ParquetReader Vectorized 286 / 296 54.9 18.2 1.0X
+ParquetReader Vectorized -> Row 249 / 253 63.1 15.8 1.1X
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
SQL Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 24990 / 25012 0.6 1588.8 1.0X
-SQL Json 9837 / 9865 1.6 625.4 2.5X
-SQL Parquet Vectorized 170 / 180 92.3 10.8 146.6X
-SQL Parquet MR 2319 / 2328 6.8 147.4 10.8X
-SQL ORC Vectorized 293 / 301 53.7 18.6 85.3X
-SQL ORC Vectorized with copy 297 / 309 52.9 18.9 84.0X
-SQL ORC MR 1667 / 1674 9.4 106.0 15.0X
-
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+SQL CSV 27701 / 27744 0.6 1761.2 1.0X
+SQL Json 9703 / 9733 1.6 616.9 2.9X
+SQL Parquet Vectorized 176 / 182 89.2 11.2 157.0X
+SQL Parquet MR 2164 / 2173 7.3 137.6 12.8X
+SQL ORC Vectorized 307 / 314 51.2 19.5 90.2X
+SQL ORC Vectorized with copy 312 / 319 50.4 19.8 88.7X
+SQL ORC MR 1690 / 1700 9.3 107.4 16.4X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Parquet Reader Single INT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-ParquetReader Vectorized 257 / 274 61.3 16.3 1.0X
-ParquetReader Vectorized -> Row 259 / 264 60.8 16.4 1.0X
+ParquetReader Vectorized 259 / 277 60.7 16.5 1.0X
+ParquetReader Vectorized -> Row 261 / 265 60.3 16.6 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
SQL Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 32537 / 32554 0.5 2068.7 1.0X
-SQL Json 12610 / 12668 1.2 801.7 2.6X
-SQL Parquet Vectorized 258 / 276 61.0 16.4 126.2X
-SQL Parquet MR 2422 / 2435 6.5 154.0 13.4X
-SQL ORC Vectorized 378 / 385 41.6 24.0 86.2X
-SQL ORC Vectorized with copy 381 / 389 41.3 24.2 85.4X
-SQL ORC MR 1797 / 1819 8.8 114.3 18.1X
-
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+SQL CSV 34813 / 34900 0.5 2213.3 1.0X
+SQL Json 12570 / 12617 1.3 799.2 2.8X
+SQL Parquet Vectorized 270 / 308 58.2 17.2 128.9X
+SQL Parquet MR 2427 / 2431 6.5 154.3 14.3X
+SQL ORC Vectorized 388 / 398 40.6 24.6 89.8X
+SQL ORC Vectorized with copy 395 / 402 39.9 25.1 88.2X
+SQL ORC MR 1819 / 1851 8.6 115.7 19.1X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Parquet Reader Single BIGINT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-ParquetReader Vectorized 352 / 368 44.7 22.4 1.0X
-ParquetReader Vectorized -> Row 351 / 359 44.8 22.3 1.0X
+ParquetReader Vectorized 372 / 379 42.3 23.7 1.0X
+ParquetReader Vectorized -> Row 357 / 368 44.1 22.7 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
SQL Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 27179 / 27184 0.6 1728.0 1.0X
-SQL Json 12578 / 12585 1.3 799.7 2.2X
-SQL Parquet Vectorized 161 / 171 97.5 10.3 168.5X
-SQL Parquet MR 2361 / 2395 6.7 150.1 11.5X
-SQL ORC Vectorized 473 / 480 33.3 30.0 57.5X
-SQL ORC Vectorized with copy 478 / 483 32.9 30.4 56.8X
-SQL ORC MR 1858 / 1859 8.5 118.2 14.6X
-
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+SQL CSV 28753 / 28781 0.5 1828.0 1.0X
+SQL Json 12039 / 12215 1.3 765.4 2.4X
+SQL Parquet Vectorized 170 / 177 92.4 10.8 169.0X
+SQL Parquet MR 2184 / 2196 7.2 138.9 13.2X
+SQL ORC Vectorized 432 / 440 36.4 27.5 66.5X
+SQL ORC Vectorized with copy 439 / 442 35.9 27.9 65.6X
+SQL ORC MR 1812 / 1833 8.7 115.2 15.9X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Parquet Reader Single FLOAT Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-ParquetReader Vectorized 251 / 255 62.7 15.9 1.0X
-ParquetReader Vectorized -> Row 255 / 259 61.8 16.2 1.0X
+ParquetReader Vectorized 253 / 260 62.2 16.1 1.0X
+ParquetReader Vectorized -> Row 256 / 257 61.6 16.2 1.0X
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
SQL Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 34797 / 34830 0.5 2212.3 1.0X
-SQL Json 17806 / 17828 0.9 1132.1 2.0X
-SQL Parquet Vectorized 260 / 269 60.6 16.5 134.0X
-SQL Parquet MR 2512 / 2534 6.3 159.7 13.9X
-SQL ORC Vectorized 582 / 593 27.0 37.0 59.8X
-SQL ORC Vectorized with copy 576 / 584 27.3 36.6 60.4X
-SQL ORC MR 2309 / 2313 6.8 146.8 15.1X
-
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+SQL CSV 36177 / 36188 0.4 2300.1 1.0X
+SQL Json 18895 / 18898 0.8 1201.3 1.9X
+SQL Parquet Vectorized 267 / 276 58.9 17.0 135.6X
+SQL Parquet MR 2355 / 2363 6.7 149.7 15.4X
+SQL ORC Vectorized 543 / 546 29.0 34.5 66.6X
+SQL ORC Vectorized with copy 548 / 557 28.7 34.8 66.0X
+SQL ORC MR 2246 / 2258 7.0 142.8 16.1X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Parquet Reader Single DOUBLE Column Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-ParquetReader Vectorized 350 / 363 44.9 22.3 1.0X
-ParquetReader Vectorized -> Row 350 / 366 44.9 22.3 1.0X
+ParquetReader Vectorized 353 / 367 44.6 22.4 1.0X
+ParquetReader Vectorized -> Row 351 / 357 44.7 22.3 1.0X
================================================================================================
Int and String Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Int and String Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 22486 / 22590 0.5 2144.5 1.0X
-SQL Json 14124 / 14195 0.7 1347.0 1.6X
-SQL Parquet Vectorized 2342 / 2347 4.5 223.4 9.6X
-SQL Parquet MR 4660 / 4664 2.2 444.4 4.8X
-SQL ORC Vectorized 2378 / 2379 4.4 226.8 9.5X
-SQL ORC Vectorized with copy 2548 / 2571 4.1 243.0 8.8X
-SQL ORC MR 4206 / 4211 2.5 401.1 5.3X
+SQL CSV 21130 / 21246 0.5 2015.1 1.0X
+SQL Json 12145 / 12174 0.9 1158.2 1.7X
+SQL Parquet Vectorized 2363 / 2377 4.4 225.3 8.9X
+SQL Parquet MR 4555 / 4557 2.3 434.4 4.6X
+SQL ORC Vectorized 2361 / 2388 4.4 225.1 9.0X
+SQL ORC Vectorized with copy 2540 / 2557 4.1 242.2 8.3X
+SQL ORC MR 4186 / 4209 2.5 399.2 5.0X
================================================================================================
Repeated String Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Repeated String: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 12150 / 12178 0.9 1158.7 1.0X
-SQL Json 7012 / 7014 1.5 668.7 1.7X
-SQL Parquet Vectorized 792 / 796 13.2 75.5 15.3X
-SQL Parquet MR 1961 / 1975 5.3 187.0 6.2X
-SQL ORC Vectorized 482 / 485 21.8 46.0 25.2X
-SQL ORC Vectorized with copy 710 / 715 14.8 67.7 17.1X
-SQL ORC MR 2081 / 2083 5.0 198.5 5.8X
+SQL CSV 11693 / 11729 0.9 1115.1 1.0X
+SQL Json 7025 / 7025 1.5 669.9 1.7X
+SQL Parquet Vectorized 803 / 821 13.1 76.6 14.6X
+SQL Parquet MR 1776 / 1790 5.9 169.4 6.6X
+SQL ORC Vectorized 491 / 494 21.4 46.8 23.8X
+SQL ORC Vectorized with copy 723 / 725 14.5 68.9 16.2X
+SQL ORC MR 2050 / 2063 5.1 195.5 5.7X
================================================================================================
Partitioned Table Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Partitioned Table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-Data column - CSV 31789 / 31791 0.5 2021.1 1.0X
-Data column - Json 12873 / 12918 1.2 818.4 2.5X
-Data column - Parquet Vectorized 267 / 280 58.9 17.0 119.1X
-Data column - Parquet MR 3387 / 3402 4.6 215.3 9.4X
-Data column - ORC Vectorized 391 / 453 40.2 24.9 81.2X
-Data column - ORC Vectorized with copy 392 / 398 40.2 24.9 81.2X
-Data column - ORC MR 2508 / 2512 6.3 159.4 12.7X
-Partition column - CSV 6965 / 6977 2.3 442.8 4.6X
-Partition column - Json 5563 / 5576 2.8 353.7 5.7X
-Partition column - Parquet Vectorized 65 / 78 241.1 4.1 487.2X
-Partition column - Parquet MR 1811 / 1811 8.7 115.1 17.6X
-Partition column - ORC Vectorized 66 / 73 239.0 4.2 483.0X
-Partition column - ORC Vectorized with copy 65 / 70 241.1 4.1 487.3X
-Partition column - ORC MR 1775 / 1778 8.9 112.8 17.9X
-Both columns - CSV 30032 / 30113 0.5 1909.4 1.1X
-Both columns - Json 13941 / 13959 1.1 886.3 2.3X
-Both columns - Parquet Vectorized 312 / 330 50.3 19.9 101.7X
-Both columns - Parquet MR 3858 / 3862 4.1 245.3 8.2X
-Both columns - ORC Vectorized 431 / 437 36.5 27.4 73.8X
-Both column - ORC Vectorized with copy 523 / 529 30.1 33.3 60.7X
-Both columns - ORC MR 2712 / 2805 5.8 172.4 11.7X
+Data column - CSV 30965 / 31041 0.5 1968.7 1.0X
+Data column - Json 12876 / 12882 1.2 818.6 2.4X
+Data column - Parquet Vectorized 277 / 282 56.7 17.6 111.6X
+Data column - Parquet MR 3398 / 3402 4.6 216.0 9.1X
+Data column - ORC Vectorized 399 / 407 39.4 25.4 77.5X
+Data column - ORC Vectorized with copy 407 / 447 38.6 25.9 76.0X
+Data column - ORC MR 2583 / 2589 6.1 164.2 12.0X
+Partition column - CSV 7403 / 7427 2.1 470.7 4.2X
+Partition column - Json 5587 / 5625 2.8 355.2 5.5X
+Partition column - Parquet Vectorized 71 / 78 222.6 4.5 438.3X
+Partition column - Parquet MR 1798 / 1808 8.7 114.3 17.2X
+Partition column - ORC Vectorized 72 / 75 219.0 4.6 431.2X
+Partition column - ORC Vectorized with copy 71 / 77 221.1 4.5 435.4X
+Partition column - ORC MR 1772 / 1778 8.9 112.6 17.5X
+Both columns - CSV 30211 / 30212 0.5 1920.7 1.0X
+Both columns - Json 13382 / 13391 1.2 850.8 2.3X
+Both columns - Parquet Vectorized 321 / 333 49.0 20.4 96.4X
+Both columns - Parquet MR 3656 / 3661 4.3 232.4 8.5X
+Both columns - ORC Vectorized 443 / 448 35.5 28.2 69.9X
+Both column - ORC Vectorized with copy 527 / 533 29.9 33.5 58.8X
+Both columns - ORC MR 2626 / 2633 6.0 167.0 11.8X
================================================================================================
String with Nulls Scan
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+String with Nulls Scan (0.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 13525 / 13823 0.8 1289.9 1.0X
-SQL Json 9913 / 9921 1.1 945.3 1.4X
-SQL Parquet Vectorized 1517 / 1517 6.9 144.7 8.9X
-SQL Parquet MR 3996 / 4008 2.6 381.1 3.4X
-ParquetReader Vectorized 1120 / 1128 9.4 106.8 12.1X
-SQL ORC Vectorized 1203 / 1224 8.7 114.7 11.2X
-SQL ORC Vectorized with copy 1639 / 1646 6.4 156.3 8.3X
-SQL ORC MR 3720 / 3780 2.8 354.7 3.6X
-
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+SQL CSV 13918 / 13979 0.8 1327.3 1.0X
+SQL Json 10068 / 10068 1.0 960.1 1.4X
+SQL Parquet Vectorized 1563 / 1564 6.7 149.0 8.9X
+SQL Parquet MR 3835 / 3836 2.7 365.8 3.6X
+ParquetReader Vectorized 1115 / 1118 9.4 106.4 12.5X
+SQL ORC Vectorized 1172 / 1208 8.9 111.8 11.9X
+SQL ORC Vectorized with copy 1630 / 1644 6.4 155.5 8.5X
+SQL ORC MR 3708 / 3711 2.8 353.6 3.8X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+String with Nulls Scan (50.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 15860 / 15877 0.7 1512.5 1.0X
-SQL Json 7676 / 7688 1.4 732.0 2.1X
-SQL Parquet Vectorized 1072 / 1084 9.8 102.2 14.8X
-SQL Parquet MR 2890 / 2897 3.6 275.6 5.5X
-ParquetReader Vectorized 1052 / 1053 10.0 100.4 15.1X
-SQL ORC Vectorized 1248 / 1248 8.4 119.0 12.7X
-SQL ORC Vectorized with copy 1627 / 1637 6.4 155.2 9.7X
-SQL ORC MR 3365 / 3369 3.1 320.9 4.7X
-
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+SQL CSV 13972 / 14043 0.8 1332.5 1.0X
+SQL Json 7436 / 7469 1.4 709.1 1.9X
+SQL Parquet Vectorized 1103 / 1112 9.5 105.2 12.7X
+SQL Parquet MR 2841 / 2847 3.7 271.0 4.9X
+ParquetReader Vectorized 992 / 1012 10.6 94.6 14.1X
+SQL ORC Vectorized 1275 / 1349 8.2 121.6 11.0X
+SQL ORC Vectorized with copy 1631 / 1644 6.4 155.5 8.6X
+SQL ORC MR 3244 / 3259 3.2 309.3 4.3X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-String with Nulls Scan: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+String with Nulls Scan (95.0%): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 13401 / 13561 0.8 1278.1 1.0X
-SQL Json 5253 / 5303 2.0 500.9 2.6X
-SQL Parquet Vectorized 233 / 242 45.0 22.2 57.6X
-SQL Parquet MR 1791 / 1796 5.9 170.8 7.5X
-ParquetReader Vectorized 236 / 238 44.4 22.5 56.7X
-SQL ORC Vectorized 453 / 473 23.2 43.2 29.6X
-SQL ORC Vectorized with copy 573 / 577 18.3 54.7 23.4X
-SQL ORC MR 1846 / 1850 5.7 176.0 7.3X
+SQL CSV 11228 / 11244 0.9 1070.8 1.0X
+SQL Json 5200 / 5247 2.0 495.9 2.2X
+SQL Parquet Vectorized 238 / 242 44.1 22.7 47.2X
+SQL Parquet MR 1730 / 1734 6.1 165.0 6.5X
+ParquetReader Vectorized 237 / 238 44.3 22.6 47.4X
+SQL ORC Vectorized 459 / 462 22.8 43.8 24.4X
+SQL ORC Vectorized with copy 581 / 583 18.1 55.4 19.3X
+SQL ORC MR 1767 / 1783 5.9 168.5 6.4X
================================================================================================
Single Column Scan From Wide Columns
================================================================================================
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Single Column Scan from 10 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 3147 / 3148 0.3 3001.1 1.0X
-SQL Json 2666 / 2693 0.4 2542.9 1.2X
-SQL Parquet Vectorized 54 / 58 19.5 51.3 58.5X
-SQL Parquet MR 220 / 353 4.8 209.9 14.3X
-SQL ORC Vectorized 63 / 77 16.8 59.7 50.3X
-SQL ORC Vectorized with copy 63 / 66 16.7 59.8 50.2X
-SQL ORC MR 317 / 321 3.3 302.2 9.9X
-
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+SQL CSV 3322 / 3356 0.3 3167.9 1.0X
+SQL Json 2808 / 2843 0.4 2678.2 1.2X
+SQL Parquet Vectorized 56 / 63 18.9 52.9 59.8X
+SQL Parquet MR 215 / 219 4.9 205.4 15.4X
+SQL ORC Vectorized 64 / 76 16.4 60.9 52.0X
+SQL ORC Vectorized with copy 64 / 67 16.3 61.3 51.7X
+SQL ORC MR 314 / 316 3.3 299.6 10.6X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Single Column Scan from 50 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 7902 / 7921 0.1 7536.2 1.0X
-SQL Json 9467 / 9491 0.1 9028.6 0.8X
-SQL Parquet Vectorized 73 / 79 14.3 69.8 108.0X
-SQL Parquet MR 239 / 247 4.4 228.0 33.1X
-SQL ORC Vectorized 78 / 84 13.4 74.6 101.0X
-SQL ORC Vectorized with copy 78 / 88 13.4 74.4 101.3X
-SQL ORC MR 910 / 918 1.2 867.6 8.7X
-
-OpenJDK 64-Bit Server VM 1.8.0_181-b13 on Linux 3.10.0-862.3.2.el7.x86_64
+SQL CSV 7978 / 7989 0.1 7608.5 1.0X
+SQL Json 10294 / 10325 0.1 9816.9 0.8X
+SQL Parquet Vectorized 72 / 85 14.5 69.0 110.3X
+SQL Parquet MR 237 / 241 4.4 226.4 33.6X
+SQL ORC Vectorized 82 / 92 12.7 78.5 97.0X
+SQL ORC Vectorized with copy 82 / 88 12.7 78.5 97.0X
+SQL ORC MR 900 / 909 1.2 858.5 8.9X
+
+OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
Single Column Scan from 100 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-SQL CSV 13539 / 13543 0.1 12912.0 1.0X
-SQL Json 17420 / 17446 0.1 16613.1 0.8X
-SQL Parquet Vectorized 103 / 120 10.2 98.1 131.6X
-SQL Parquet MR 250 / 258 4.2 238.9 54.1X
-SQL ORC Vectorized 99 / 104 10.6 94.6 136.5X
-SQL ORC Vectorized with copy 100 / 106 10.5 95.6 135.1X
-SQL ORC MR 1653 / 1659 0.6 1576.3 8.2X
+SQL CSV 13489 / 13508 0.1 12864.3 1.0X
+SQL Json 18813 / 18827 0.1 17941.4 0.7X
+SQL Parquet Vectorized 107 / 111 9.8 101.8 126.3X
+SQL Parquet MR 275 / 286 3.8 262.3 49.0X
+SQL ORC Vectorized 107 / 115 9.8 101.7 126.4X
+SQL ORC Vectorized with copy 107 / 115 9.8 102.3 125.8X
+SQL ORC MR 1659 / 1664 0.6 1582.3 8.1X
diff --git a/sql/core/benchmarks/WideTableBenchmark-results.txt b/sql/core/benchmarks/WideTableBenchmark-results.txt
index 3b41a3e036c4d..8c09f9ca11307 100644
--- a/sql/core/benchmarks/WideTableBenchmark-results.txt
+++ b/sql/core/benchmarks/WideTableBenchmark-results.txt
@@ -6,12 +6,12 @@ OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 3.10.0-862.3.2.el7.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
projection on wide table: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
-split threshold 10 38932 / 39307 0.0 37128.1 1.0X
-split threshold 100 31991 / 32556 0.0 30508.8 1.2X
-split threshold 1024 10993 / 11041 0.1 10483.5 3.5X
-split threshold 2048 8959 / 8998 0.1 8543.8 4.3X
-split threshold 4096 8116 / 8134 0.1 7739.8 4.8X
-split threshold 8196 8069 / 8098 0.1 7695.5 4.8X
-split threshold 65536 57068 / 57339 0.0 54424.3 0.7X
+split threshold 10 40571 / 40937 0.0 38691.7 1.0X
+split threshold 100 31116 / 31669 0.0 29674.6 1.3X
+split threshold 1024 10077 / 10199 0.1 9609.7 4.0X
+split threshold 2048 8654 / 8692 0.1 8253.2 4.7X
+split threshold 4096 8006 / 8038 0.1 7634.7 5.1X
+split threshold 8192 8069 / 8107 0.1 7695.3 5.0X
+split threshold 65536 56973 / 57204 0.0 54333.7 0.7X
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 95e98c5444721..ac5f1fc923e7d 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -21,12 +21,12 @@
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>org.apache.spark</groupId>
-   <artifactId>spark-parent_2.11</artifactId>
+   <artifactId>spark-parent_2.12</artifactId>
    <version>3.0.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
  </parent>

- <artifactId>spark-sql_2.11</artifactId>
+ <artifactId>spark-sql_2.12</artifactId>
  <packaging>jar</packaging>
  <name>Spark Project SQL</name>
  <url>http://spark.apache.org/</url>
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
index 802949c0ddb60..d4e1d89491f43 100644
--- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
+++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
@@ -20,8 +20,8 @@
import java.io.Serializable;
import java.util.Iterator;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.annotation.Experimental;
-import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.streaming.GroupState;
/**
@@ -33,7 +33,7 @@
* @since 2.1.1
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
Iterator<R> call(K key, Iterator<V> values, GroupState<S> state) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
index 353e9886a8a57..f0abfde843cc5 100644
--- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
+++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
@@ -20,8 +20,8 @@
import java.io.Serializable;
import java.util.Iterator;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.annotation.Experimental;
-import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.streaming.GroupState;
/**
@@ -32,7 +32,7 @@
* @since 2.1.1
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
R call(K key, Iterator<V> values, GroupState<S> state) throws Exception;
}
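The two functional interfaces above (FlatMapGroupsWithStateFunction and MapGroupsWithStateFunction) only trade @InterfaceStability.Evolving for @Evolving; their shape is unchanged. As a reminder of how they are consumed from Java, here is a minimal sketch, not part of this diff and with illustrative names, that keeps a running count per key via mapGroupsWithState:

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.api.java.function.MapGroupsWithStateFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.streaming.GroupStateTimeout;

public class RunningCountSketch {
  // Produces one "key: count" string per key, keeping the count in GroupState across triggers.
  public static Dataset<String> runningCounts(Dataset<String> words) {
    MapGroupsWithStateFunction<String, String, Long, String> countFn =
        (key, values, state) -> {
          long count = state.exists() ? state.get() : 0L;
          while (values.hasNext()) {
            values.next();
            count++;
          }
          state.update(count);  // carried forward for this key
          return key + ": " + count;
        };
    return words
        .groupByKey((MapFunction<String, String>) w -> w, Encoders.STRING())
        .mapGroupsWithState(countFn, Encoders.LONG(), Encoders.STRING(),
            GroupStateTimeout.NoTimeout());
  }
}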
diff --git a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java
index 1c3c9794fb6bb..9cc073f53a3eb 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java
@@ -16,14 +16,14 @@
*/
package org.apache.spark.sql;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* SaveMode is used to specify the expected behavior of saving a DataFrame to a data source.
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
public enum SaveMode {
/**
* Append mode means that when saving a DataFrame to a data source, if data/table already exists,
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java
index 4eeb7be3f5abb..631d6eb1cfb03 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF0.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 0 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF0<R> extends Serializable {
R call() throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java
index 1460daf27dc20..a5d01406edd8c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 1 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF1<T1, R> extends Serializable {
R call(T1 t1) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java
index 7c4f1e4897084..effe99e30b2a5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 10 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF10<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java
index 26a05106aebd6..e70b18b84b08f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 11 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF11<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java
index 8ef7a99042025..339feb34135e1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 12 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF12<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java
index 5c3b2ec1222e2..d346e5c908c6f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 13 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF13<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java
index 97e744d843466..d27f9f5270f4b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 14 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF14<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java
index 7ddbf914fc11a..b99b57a91d465 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 15 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF15<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java
index 0ae5dc7195ad6..7899fc4b7ad65 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 16 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF16<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java
index 03543a556c614..40a7e95724fc2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 17 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF17<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java
index 46740d3443916..47935a935891c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 18 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF18<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java
index 33fefd8ecaf1d..578b796ff03a3 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 19 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF19<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java
index 9822f19217d76..2f856aa3cf630 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 2 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF2<T1, T2, R> extends Serializable {
R call(T1 t1, T2 t2) throws Exception;
}
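The UDF0 through UDF22 interfaces in this diff are likewise annotation-only changes (@InterfaceStability.Stable becomes @Stable). For context, a small sketch of how one of these arity-specific interfaces is registered and used from Java; the function name and query below are made up for illustration:

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;

public class Udf2RegistrationSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("udf2-sketch")
        .master("local[*]")
        .getOrCreate();

    // UDF2<T1, T2, R>: two inputs, one result, registered under a SQL-visible name.
    spark.udf().register(
        "str_repeat",
        (UDF2<String, Integer, String>) (s, n) ->
            String.join("", java.util.Collections.nCopies(n, s)),
        DataTypes.StringType);

    spark.sql("SELECT str_repeat('ab', 3)").show();  // prints "ababab"
    spark.stop();
  }
}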
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java
index 8c5e90182da1c..aa8a9fa897040 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 20 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF20<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java
index e3b09f5167cff..0fe52bce2eca2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 21 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF21<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java
index dc6cfa9097bab..69fd8ca422833 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 22 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF22<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21, T22, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java
index 7c264b69ba195..84ffd655672a2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 3 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF3<T1, T2, T3, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java
index 58df38fc3c911..dd2dc285c226d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 4 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF4<T1, T2, T3, T4, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java
index 4146f96e2eed5..795cc21c3f76e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 5 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF5<T1, T2, T3, T4, T5, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java
index 25d39654c1095..a954684c3c9a9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 6 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF6<T1, T2, T3, T4, T5, T6, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java
index ce63b6a91adbb..03761f2c9ebbf 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 7 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF7<T1, T2, T3, T4, T5, T6, T7, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java
index 0e00209ef6b9f..8cd3583b2cbf0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 8 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF8<T1, T2, T3, T4, T5, T6, T7, T8, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java
index 077981bb3e3ee..78a7097791963 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java
@@ -19,12 +19,12 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Stable;
/**
* A Spark SQL UDF that has 9 arguments.
*/
-@InterfaceStability.Stable
+@Stable
public interface UDF9<T1, T2, T3, T4, T5, T6, T7, T8, T9, R> extends Serializable {
R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java
index 82a1169cbe7ae..7d1fbe64fc960 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/SchemaColumnConvertNotSupportedException.java
@@ -17,12 +17,12 @@
package org.apache.spark.sql.execution.datasources;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Unstable;
/**
* Exception thrown when the parquet reader find column type mismatches.
*/
-@InterfaceStability.Unstable
+@Unstable
public class SchemaColumnConvertNotSupportedException extends RuntimeException {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index b0e119d658cb4..4f5e72c1326ac 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -101,10 +101,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
String message = "Cannot reserve additional contiguous bytes in the vectorized reader (" +
(requiredCapacity >= 0 ? "requested " + requiredCapacity + " bytes" : "integer overflow") +
"). As a workaround, you can reduce the vectorized reader batch size, or disable the " +
- "vectorized reader. For parquet file format, refer to " +
+ "vectorized reader, or disable " + SQLConf.BUCKETING_ENABLED().key() + " if you read " +
+ "from bucket table. For Parquet file format, refer to " +
SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().key() +
" (default " + SQLConf.PARQUET_VECTORIZED_READER_BATCH_SIZE().defaultValueString() +
- ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for orc file format, " +
+ ") and " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + "; for ORC file format, " +
"refer to " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().key() +
" (default " + SQLConf.ORC_VECTORIZED_READER_BATCH_SIZE().defaultValueString() +
") and " + SQLConf.ORC_VECTORIZED_READER_ENABLED().key() + ".";
diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java
index ec9c107b1c119..5a72f0c6a2555 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java
@@ -17,8 +17,8 @@
package org.apache.spark.sql.expressions.javalang;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.annotation.Experimental;
-import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.execution.aggregate.TypedAverage;
@@ -35,7 +35,7 @@
* @since 2.0.0
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
public class typed {
// Note: make sure to keep in sync with typed.scala
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java
index f403dc619e86c..2a4933d75e8d0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchReadSupportProvider.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils;
import org.apache.spark.sql.sources.v2.reader.BatchReadSupport;
import org.apache.spark.sql.types.StructType;
@@ -29,7 +29,7 @@
* This interface is used to create {@link BatchReadSupport} instances when end users run
* {@code SparkSession.read.format(...).option(...).load()}.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface BatchReadSupportProvider extends DataSourceV2 {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java
index bd10c3353bf12..df439e2c02fe3 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/BatchWriteSupportProvider.java
@@ -19,7 +19,7 @@
import java.util.Optional;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.sources.v2.writer.BatchWriteSupport;
import org.apache.spark.sql.types.StructType;
@@ -31,7 +31,7 @@
* This interface is used to create {@link BatchWriteSupport} instances when end users run
* {@code Dataset.write.format(...).option(...).save()}.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface BatchWriteSupportProvider extends DataSourceV2 {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java
index 824c290518acf..b4f2eb34a1560 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ContinuousReadSupportProvider.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport;
import org.apache.spark.sql.types.StructType;
@@ -29,7 +29,7 @@
* This interface is used to create {@link ContinuousReadSupport} instances when end users run
* {@code SparkSession.readStream.format(...).option(...).load()} with a continuous trigger.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface ContinuousReadSupportProvider extends DataSourceV2 {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
index 83df3be747085..1c5e3a0cd31e7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
@@ -26,7 +26,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* An immutable string-to-string map in which keys are case-insensitive. This is used to represent
@@ -73,7 +73,7 @@
*
*
*/
-@InterfaceStability.Evolving
+@Evolving
public class DataSourceOptions {
private final Map<String, String> keyLowerCasedMap;
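DataSourceOptions, whose annotation changes above, documents itself as an immutable string-to-string map with case-insensitive keys. A small illustrative sketch of that behaviour; the keys and path are made up:

import java.util.HashMap;
import java.util.Map;
import org.apache.spark.sql.sources.v2.DataSourceOptions;

public class DataSourceOptionsSketch {
  public static void main(String[] args) {
    Map<String, String> raw = new HashMap<>();
    raw.put("PATH", "/tmp/example");  // mixed-case key supplied by a caller

    DataSourceOptions options = new DataSourceOptions(raw);

    // Keys are lower-cased internally, so lookups ignore case.
    System.out.println(options.get("path").orElse("<missing>"));  // /tmp/example
    System.out.println(options.get("Path").orElse("<missing>"));  // /tmp/example
  }
}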
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java
index 6e31e84bf6c72..eae7a45d1d446 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* The base interface for data source v2. Implementations must have a public, 0-arg constructor.
@@ -30,5 +30,5 @@
* If Spark fails to execute any methods in the implementations of this interface (by throwing an
* exception), the read action will fail and no Spark job will be submitted.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface DataSourceV2 {}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java
index 61c08e7fa89df..c4d9ef88f607e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/MicroBatchReadSupportProvider.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils;
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport;
import org.apache.spark.sql.types.StructType;
@@ -29,7 +29,7 @@
* This interface is used to create {@link MicroBatchReadSupport} instances when end users run
* {@code SparkSession.readStream.format(...).option(...).load()} with a micro-batch trigger.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface MicroBatchReadSupportProvider extends DataSourceV2 {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
index bbe430e299261..c00abd9b685b5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
@@ -17,14 +17,14 @@
package org.apache.spark.sql.sources.v2;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
* propagate session configs with the specified key-prefix to all data source operations in this
* session.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface SessionConfigSupport extends DataSourceV2 {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java
index f9ca85d8089b4..8ac9c51750865 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.execution.streaming.BaseStreamingSink;
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport;
import org.apache.spark.sql.streaming.OutputMode;
@@ -30,7 +30,7 @@
* This interface is used to create {@link StreamingWriteSupport} instances when end users run
* {@code Dataset.writeStream.format(...).option(...).start()}.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java
index 452ee86675b42..518a8b03a2c6e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/BatchReadSupport.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* An interface that defines how to load the data from data source for batch processing.
@@ -29,7 +29,7 @@
* {@link ScanConfig}. The {@link ScanConfig} will be used to create input partitions and reader
* factory to scan data from the data source with a Spark job.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface BatchReadSupport extends ReadSupport {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
index 95c30de907e44..5f5248084bad6 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
@@ -19,7 +19,7 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* A serializable representation of an input partition returned by
@@ -32,7 +32,7 @@
* the actual reading. So {@link InputPartition} must be serializable while {@link PartitionReader}
* doesn't need to be.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface InputPartition extends Serializable {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java
index 04ff8d0a19fc3..2945925959538 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java
@@ -20,7 +20,7 @@
import java.io.Closeable;
import java.io.IOException;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* A partition reader returned by {@link PartitionReaderFactory#createReader(InputPartition)} or
@@ -32,7 +32,7 @@
* data sources(whose {@link PartitionReaderFactory#supportColumnarReads(InputPartition)}
* returns true).
*/
-@InterfaceStability.Evolving
+@Evolving
public interface PartitionReader<T> extends Closeable {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java
index f35de9310eee3..97f4a473953fc 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java
@@ -19,7 +19,7 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.vectorized.ColumnarBatch;
@@ -30,7 +30,7 @@
* {@link PartitionReader} (by throwing an exception), corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface PartitionReaderFactory extends Serializable {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java
index a58ddb288f1ed..b1f610a82e8a2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadSupport.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.types.StructType;
/**
@@ -27,7 +27,7 @@
* If Spark fails to execute any methods in the implementations of this interface (by throwing an
* exception), the read action will fail and no Spark job will be submitted.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface ReadSupport {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java
index 7462ce2820585..a69872a527746 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfig.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.types.StructType;
/**
@@ -31,7 +31,7 @@
* {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}, implementations mostly need to
* cast the input {@link ScanConfig} to the concrete {@link ScanConfig} class of the data source.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface ScanConfig {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java
index 4c0eedfddfe22..4922962f70655 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanConfigBuilder.java
@@ -17,14 +17,14 @@
package org.apache.spark.sql.sources.v2.reader;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* An interface for building the {@link ScanConfig}. Implementations can mixin those
* SupportsPushDownXYZ interfaces to do operator pushdown, and keep the operator pushdown result in
* the returned {@link ScanConfig}.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface ScanConfigBuilder {
ScanConfig build();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
index 44799c7d49137..14776f37fed46 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
@@ -19,13 +19,13 @@
import java.util.OptionalLong;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* An interface to represent statistics for a data source, which is returned by
* {@link SupportsReportStatistics#estimateStatistics(ScanConfig)}.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface Statistics {
OptionalLong sizeInBytes();
OptionalLong numRows();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
index 5e7985f645a06..3a89baa1b44c2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
@@ -17,14 +17,14 @@
package org.apache.spark.sql.sources.v2.reader;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.Filter;
/**
* A mix-in interface for {@link ScanConfigBuilder}. Data sources can implement this interface to
* push down filters to the data source and reduce the size of the data to be read.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface SupportsPushDownFilters extends ScanConfigBuilder {
/**
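
Note: the hunks above only swap the stability annotation on SupportsPushDownFilters. As a hedged sketch of how a source could use this mix-in (the NoOpPushDownBuilder name and the choice to push nothing down are illustrative assumptions, not part of this patch), a builder that records the offered filters but leaves all evaluation to Spark might look like:

    import org.apache.spark.sql.sources.Filter
    import org.apache.spark.sql.sources.v2.reader.{ScanConfig, SupportsPushDownFilters}
    import org.apache.spark.sql.types.StructType

    // Illustrative builder: it remembers the filters Spark offers but handles none of them,
    // so all filters are returned as residuals and Spark re-evaluates them after the scan.
    class NoOpPushDownBuilder(schema: StructType) extends SupportsPushDownFilters {

      private var offered: Array[Filter] = Array.empty

      override def pushFilters(filters: Array[Filter]): Array[Filter] = {
        offered = filters   // keep a record of what was offered
        filters             // everything still has to be evaluated by Spark
      }

      override def pushedFilters(): Array[Filter] = Array.empty  // nothing handled at the source

      override def build(): ScanConfig = new ScanConfig {
        override def readSchema(): StructType = schema
      }
    }
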
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java
index edb164937d6ef..1934763224881 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.types.StructType;
/**
@@ -25,7 +25,7 @@
* interface to push down required columns to the data source and only read these columns during
* scan to reduce the size of the data to be read.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface SupportsPushDownRequiredColumns extends ScanConfigBuilder {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
index db62cd4515362..0335c7775c2af 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;
/**
@@ -27,7 +27,7 @@
* Note that, when a {@link ReadSupport} implementation creates exactly one {@link InputPartition},
* Spark may avoid adding a shuffle even if the reader does not implement this interface.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface SupportsReportPartitioning extends ReadSupport {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
index 1831488ba096f..917372cdd25b3 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* A mix in interface for {@link BatchReadSupport}. Data sources can implement this interface to
@@ -27,7 +27,7 @@
* data source. Implementations that return more accurate statistics based on pushed operators will
* not improve query performance until the planner can push operators before getting stats.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface SupportsReportStatistics extends ReadSupport {
/**
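
For orientation only, a minimal Statistics implementation built from the two accessors shown in the Statistics hunk above could look like the following; the KnownRowCountStats name is invented for illustration:

    import java.util.OptionalLong
    import org.apache.spark.sql.sources.v2.reader.Statistics

    // Illustrative statistics for a source whose row count is known exactly but whose
    // on-disk size is not; unknown values are reported as OptionalLong.empty().
    class KnownRowCountStats(rowCount: Long) extends Statistics {
      override def sizeInBytes(): OptionalLong = OptionalLong.empty()
      override def numRows(): OptionalLong = OptionalLong.of(rowCount)
    }

A SupportsReportStatistics implementation would return such an object from estimateStatistics(config), ideally after taking the pushed operators in the ScanConfig into account, as the javadoc above cautions.
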
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
index 6764d4b7665c7..1cdc02f5736b1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader.partitioning;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.v2.reader.PartitionReader;
/**
@@ -25,7 +25,7 @@
* share the same values for the {@link #clusteredColumns} will be produced by the same
* {@link PartitionReader}.
*/
-@InterfaceStability.Evolving
+@Evolving
public class ClusteredDistribution implements Distribution {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
index 364a3f553923c..02b0e68974919 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader.partitioning;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.v2.reader.PartitionReader;
/**
@@ -37,5 +37,5 @@
 *   <li>{@link ClusteredDistribution}</li>
 * </ul>
 */
-@InterfaceStability.Evolving
+@Evolving
public interface Distribution {}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
index fb0b6f1df43bb..c9a00262c1287 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader.partitioning;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning;
@@ -28,7 +28,7 @@
* like a snapshot. Once created, it should be deterministic and always report the same number of
* partitions and the same "satisfy" result for a certain distribution.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface Partitioning {
/**
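
As a hedged sketch of the reporting side (the class name and column are invented, and it assumes the numPartitions()/satisfy() pair of this Partitioning interface), a source that claims its partitions are clustered on a single column might report:

    import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}

    // Illustrative partitioning: the source claims clustering on one column, so Spark can
    // skip a shuffle when a query requires exactly that clustering.
    class ClusteredByColumn(column: String, parts: Int) extends Partitioning {
      override def numPartitions(): Int = parts

      override def satisfy(distribution: Distribution): Boolean = distribution match {
        case c: ClusteredDistribution => c.clusteredColumns.sameElements(Array(column))
        case _ => false
      }
    }
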
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java
index 9101c8a44d34e..c7f6fce6e81af 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java
@@ -17,13 +17,13 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.v2.reader.PartitionReader;
/**
* A variation on {@link PartitionReader} for use with continuous streaming processing.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface ContinuousPartitionReader extends PartitionReader {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java
index 2d9f1ca1686a1..41195befe5e57 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory;
@@ -28,7 +28,7 @@
* instead of {@link org.apache.spark.sql.sources.v2.reader.PartitionReader}. It's used for
* continuous streaming processing.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory {
@Override
ContinuousPartitionReader createReader(InputPartition partition);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java
index 9a3ad2eb8a801..2b784ac0e9f35 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReadSupport.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.ScanConfig;
@@ -36,7 +36,7 @@
* {@link #stop()} will be called when the streaming execution is completed. Note that a single
* query may have multiple executions due to restart or failure recovery.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface ContinuousReadSupport extends StreamingReadSupport, BaseStreamingSource {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java
index edb0db11bff2c..f56066c639388 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchReadSupport.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
import org.apache.spark.sql.sources.v2.reader.*;
@@ -33,7 +33,7 @@
* will be called when the streaming execution is completed. Note that a single query may have
* multiple executions due to restart or failure recovery.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface MicroBatchReadSupport extends StreamingReadSupport, BaseStreamingSource {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
index 6cf27734867cb..6104175d2c9e3 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* An abstract representation of progress through a {@link MicroBatchReadSupport} or
@@ -30,7 +30,7 @@
* maintain compatibility with DataSource V1 APIs. This extension will be removed once we
* get rid of V1 completely.
*/
-@InterfaceStability.Evolving
+@Evolving
public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset {
/**
* A JSON-serialized representation of an Offset that is
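
A hedged example of a concrete offset, assuming only the json() method of the abstract class touched above; CounterOffset is an invented name:

    import org.apache.spark.sql.sources.v2.reader.streaming.Offset

    // Illustrative offset for a source whose progress is a single increasing counter.
    // json() is the one abstract method; the string it returns is what ends up in the offset log.
    case class CounterOffset(value: Long) extends Offset {
      override def json(): String = value.toString
    }
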
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java
index 383e73db6762b..2c97d924a0629 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java
@@ -19,7 +19,7 @@
import java.io.Serializable;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* Used for per-partition offsets in continuous processing. ContinuousReader implementations will
@@ -27,6 +27,6 @@
*
* These offsets must be serializable.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface PartitionOffset extends Serializable {
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java
index 0ec9e05d6a02b..efe1ac4f78db1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWriteSupport.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.writer;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* An interface that defines how to write the data to data source for batch processing.
@@ -37,7 +37,7 @@
*
* Please refer to the documentation of commit/abort methods for detailed specifications.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface BatchWriteSupport {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
index 5fb067966ee67..d142ee523ef9f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
@@ -19,7 +19,7 @@
import java.io.IOException;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
/**
* A data writer returned by {@link DataWriterFactory#createWriter(int, long)} and is
@@ -55,7 +55,7 @@
*
* Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface DataWriter {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
index 19a36dd232456..65105f46b82d5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
@@ -20,7 +20,7 @@
import java.io.Serializable;
import org.apache.spark.TaskContext;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.InternalRow;
/**
@@ -31,7 +31,7 @@
* will be created on executors and do the actual writing. So this interface must be
* serializable and {@link DataWriter} doesn't need to be.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface DataWriterFactory extends Serializable {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java
index 123335c414e9f..9216e34399092 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java
@@ -19,8 +19,8 @@
import java.io.Serializable;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport;
-import org.apache.spark.annotation.InterfaceStability;
/**
* A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side
@@ -30,5 +30,5 @@
* This is an empty interface, data sources should define their own message class and use it when
* generating messages at executor side and handling the messages at driver side.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface WriterCommitMessage extends Serializable {}
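
For illustration, a commit message only has to be serializable, so a small case class is enough; FilesWritten is an invented name and its fields are placeholders:

    import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage

    // Illustrative message a DataWriter could return from commit() to tell the driver
    // which files it produced and how many rows were written.
    case class FilesWritten(paths: Seq[String], rowCount: Long) extends WriterCommitMessage
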
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java
index a4da24fc5ae68..7d3d21cb2b637 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java
@@ -20,7 +20,7 @@
import java.io.Serializable;
import org.apache.spark.TaskContext;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.v2.writer.DataWriter;
@@ -33,7 +33,7 @@
* will be created on executors and do the actual writing. So this interface must be
* serializable and {@link DataWriter} doesn't need to be.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface StreamingDataWriterFactory extends Serializable {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java
index 3fdfac5e1c84a..84cfbf2dda483 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources.v2.writer.streaming;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.sources.v2.writer.DataWriter;
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
@@ -27,7 +27,7 @@
* Streaming queries are divided into intervals of data called epochs, with a monotonically
* increasing numeric ID. This writer handles commits and aborts for each successive epoch.
*/
-@InterfaceStability.Evolving
+@Evolving
public interface StreamingWriteSupport {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java
index 5371a23230c98..fd6f7be2abc5a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java
@@ -19,9 +19,9 @@
import java.util.concurrent.TimeUnit;
+import org.apache.spark.annotation.Evolving;
import scala.concurrent.duration.Duration;
-import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger;
import org.apache.spark.sql.execution.streaming.OneTimeTrigger$;
@@ -30,7 +30,7 @@
*
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
public class Trigger {
/**
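
Usage of the Trigger class is unchanged by the annotation swap above; a hedged usage sketch, assuming spark is an active SparkSession and using the built-in rate source to stay self-contained:

    import org.apache.spark.sql.streaming.Trigger

    val rates = spark.readStream.format("rate").load()

    val query = rates.writeStream
      .format("console")
      .trigger(Trigger.ProcessingTime("10 seconds"))   // micro-batch every 10 seconds
      // .trigger(Trigger.Once())                      // alternatively: a single micro-batch
      // .trigger(Trigger.Continuous("1 second"))      // or continuous mode, 1s checkpoint interval
      .start()
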
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index 5f58b031f6aef..906e9bc26ef53 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -22,7 +22,7 @@
import org.apache.arrow.vector.complex.*;
import org.apache.arrow.vector.holders.NullableVarCharHolder;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.execution.arrow.ArrowUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.UTF8String;
@@ -31,7 +31,7 @@
* A column vector backed by Apache Arrow. Currently calendar interval type and map type are not
* supported.
*/
-@InterfaceStability.Evolving
+@Evolving
public final class ArrowColumnVector extends ColumnVector {
private final ArrowVectorAccessor accessor;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index ad99b450a4809..14caaeaedbe2b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.vectorized;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.types.CalendarInterval;
@@ -47,7 +47,7 @@
* format. Since it is expected to reuse the ColumnVector instance while loading data, the storage
* footprint is negligible.
*/
-@InterfaceStability.Evolving
+@Evolving
public abstract class ColumnVector implements AutoCloseable {
/**
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
index 72a192d089b9f..dd2bd789c26d0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.vectorized;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
@@ -25,7 +25,7 @@
/**
* Array abstraction in {@link ColumnVector}.
*/
-@InterfaceStability.Evolving
+@Evolving
public final class ColumnarArray extends ArrayData {
// The data for this array. This array contains elements from
// data[offset] to data[offset + length).
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
index d206c1df42abb..07546a54013ec 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
@@ -18,7 +18,7 @@
import java.util.*;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow;
@@ -27,7 +27,7 @@
* batch so that Spark can access the data row by row. Instance of it is meant to be reused during
* the entire data loading process.
*/
-@InterfaceStability.Evolving
+@Evolving
public final class ColumnarBatch {
private int numRows;
private final ColumnVector[] columns;
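
As a hedged sketch of consuming the columnar API touched above (the function name is invented, and the batch is assumed to come from a columnar scan with an int first column):

    import org.apache.spark.sql.vectorized.ColumnarBatch

    // Walks the row view that ColumnarBatch exposes over its column vectors and sums
    // the first column; the returned rows are reused views, so values are read eagerly.
    def sumFirstIntColumn(batch: ColumnarBatch): Long = {
      var total = 0L
      val rows = batch.rowIterator()   // java.util.Iterator[InternalRow]
      while (rows.hasNext) {
        val row = rows.next()
        if (!row.isNullAt(0)) total += row.getInt(0)
      }
      total
    }
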
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
index f2f2279590023..4b9d3c5f59915 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.vectorized;
-import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.*;
@@ -26,7 +26,7 @@
/**
* Row abstraction in {@link ColumnVector}.
*/
-@InterfaceStability.Evolving
+@Evolving
public final class ColumnarRow extends InternalRow {
// The data for this row.
// E.g. the value of 3rd int field is `data.getChild(3).getInt(rowId)`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index a9a19aa8a1001..5a408b29f9337 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
import scala.language.implicitConversions
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
@@ -60,7 +60,7 @@ private[sql] object Column {
*
* @since 1.6.0
*/
-@InterfaceStability.Stable
+@Stable
class TypedColumn[-T, U](
expr: Expression,
private[sql] val encoder: ExpressionEncoder[U])
@@ -74,6 +74,9 @@ class TypedColumn[-T, U](
inputEncoder: ExpressionEncoder[_],
inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
+
+ // This only inserts inputs into typed aggregate expressions. For untyped aggregate expressions,
+ // the resolving is handled in the analyzer directly.
val newExpr = expr transform {
case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
ta.withInputInfo(
@@ -127,7 +130,7 @@ class TypedColumn[-T, U](
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class Column(val expr: Expression) extends Logging {
def this(name: String) = this(name match {
@@ -1224,7 +1227,7 @@ class Column(val expr: Expression) extends Logging {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class ColumnName(name: String) extends Column(name) {
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 5288907b7d7ff..53e9f810d7c85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -22,18 +22,17 @@ import java.util.Locale
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-
/**
* Functionality for working with missing data in `DataFrame`s.
*
* @since 1.3.1
*/
-@InterfaceStability.Stable
+@Stable
final class DataFrameNaFunctions private[sql](df: DataFrame) {
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 95c97e5c9433c..da88598eed061 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -24,11 +24,12 @@ import scala.collection.JavaConverters._
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.spark.Partition
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.FailureSafeParser
import org.apache.spark.sql.execution.command.DDLUtils
@@ -48,7 +49,7 @@ import org.apache.spark.unsafe.types.UTF8String
*
* @since 1.4.0
*/
-@InterfaceStability.Stable
+@Stable
class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
/**
@@ -194,7 +195,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
- val ds = cls.newInstance().asInstanceOf[DataSourceV2]
+ val ds = cls.getConstructor().newInstance().asInstanceOf[DataSourceV2]
if (ds.isInstanceOf[BatchReadSupportProvider]) {
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
ds = ds, conf = sparkSession.sessionState.conf)
@@ -384,6 +385,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
 * for schema inferring.</li>
 * <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
 * empty array/struct during schema inference.</li>
+ * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+ * For instance, this is used while parsing dates and timestamps.</li>
 * </ul>
 *
* @since 2.0.0
@@ -440,7 +443,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
}
- verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+ ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
val actualSchema =
StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
@@ -502,7 +505,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
parsedOptions)
}
- verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+ ExprUtils.verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
val actualSchema =
StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
@@ -604,6 +607,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
 * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
 * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
+ * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+ * For instance, this is used while parsing dates and timestamps.</li>
+ * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing. Maximum length is 1 character.</li>
 * </ul>
 *
* @since 2.0.0
@@ -761,22 +768,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
}
- /**
- * A convenient function for schema validation in datasources supporting
- * `columnNameOfCorruptRecord` as an option.
- */
- private def verifyColumnNameOfCorruptRecord(
- schema: StructType,
- columnNameOfCorruptRecord: String): Unit = {
- schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
- val f = schema(corruptFieldIndex)
- if (f.dataType != StringType || !f.nullable) {
- throw new AnalysisException(
- "The field for corrupt records must be string type and nullable")
- }
- }
- }
-
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
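
The corrupt-record validation removed above now lives in ExprUtils.verifyColumnNameOfCorruptRecord, but the user-facing contract is unchanged: the column must be a nullable string. A hedged reader sketch, with spark assumed to be a SparkSession and the path a placeholder:

    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

    // The corrupt-record column is declared as a nullable string, which is exactly what the
    // relocated check requires; a non-string or non-nullable field would fail the read.
    val schema = StructType(Seq(
      StructField("id", LongType),
      StructField("_corrupt_record", StringType, nullable = true)))

    val records = spark.read
      .schema(schema)
      .option("mode", "PERMISSIVE")
      .option("columnNameOfCorruptRecord", "_corrupt_record")
      .option("locale", "en-US")     // option documented in the hunks above
      .json("/data/input.json")
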
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 7c12432d33c33..0b22b898557f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -21,7 +21,7 @@ import java.{lang => jl, util => ju}
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.functions.col
@@ -33,7 +33,7 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
*
* @since 1.4.0
*/
-@InterfaceStability.Stable
+@Stable
final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
@@ -51,7 +51,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
*
* This method implements a variation of the Greenwald-Khanna algorithm (with some speed
* optimizations).
- * The algorithm was first present in
+ * The algorithm was first present in
* Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna.
*
* @param col the name of the numerical column
@@ -218,7 +218,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
* Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
- * here, proposed by Karp,
+ * here, proposed by Karp,
* Schenker, and Papadimitriou.
* The `support` should be greater than 1e-4.
*
@@ -265,7 +265,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
* Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
- * here, proposed by Karp,
+ * here, proposed by Karp,
* Schenker, and Papadimitriou.
* Uses a `default` support of 1%.
*
@@ -284,7 +284,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
* (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
- * here, proposed by Karp, Schenker,
+ * here, proposed by Karp, Schenker,
* and Papadimitriou.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
@@ -328,7 +328,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
* (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
- * here, proposed by Karp, Schenker,
+ * here, proposed by Karp, Schenker,
* and Papadimitriou.
* Uses a `default` support of 1%.
*
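
The hunks above only touch scaladoc links; for orientation, a usage sketch of the two methods involved, with events standing in for any DataFrame that has a numeric "latency" column:

    // Approximate quantiles via the Greenwald-Khanna variant, then frequent items.
    val Array(p50, p95) = events.stat.approxQuantile("latency", Array(0.5, 0.95), 0.01)
    val frequent = events.stat.freqItems(Seq("latency"), 0.01)   // support must exceed 1e-4
    frequent.show()
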
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 5a28870f5d3c2..5a807d3d4b93e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -21,7 +21,7 @@ import java.util.{Locale, Properties, UUID}
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog._
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types.StructType
*
* @since 1.4.0
*/
-@InterfaceStability.Stable
+@Stable
final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
private val df = ds.toDF()
@@ -243,7 +243,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
- val source = cls.newInstance().asInstanceOf[DataSourceV2]
+ val source = cls.getConstructor().newInstance().asInstanceOf[DataSourceV2]
source match {
case provider: BatchWriteSupportProvider =>
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
@@ -658,6 +658,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
 * whitespaces from values being written should be skipped.</li>
 * <li>`ignoreTrailingWhiteSpace` (default `true`): a flag indicating defines whether or not
 * trailing whitespaces from values being written should be skipped.</li>
+ * <li>`lineSep` (default `\n`): defines the line separator that should be used for writing.
+ * Maximum length is 1 character.</li>
 * </ul>
 *
* @since 2.0.0
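
A hedged sketch of the writer option documented above; events and the output path are placeholders:

    // CSV write with an explicit single-character line separator, per the added documentation.
    events.write
      .option("header", "true")
      .option("ignoreTrailingWhiteSpace", "true")
      .option("lineSep", "\n")
      .csv("/data/events_csv")
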
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f835343d6e067..b10d66dfb1aef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,19 +21,18 @@ import java.io.CharArrayWriter
import scala.collection.JavaConverters._
import scala.language.implicitConversions
-import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
import org.apache.commons.lang3.StringUtils
import org.apache.spark.TaskContext
-import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.encoders._
@@ -78,6 +77,14 @@ private[sql] object Dataset {
qe.assertAnalyzed()
new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
}
+
+ /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */
+ def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker)
+ : DataFrame = {
+ val qe = new QueryExecution(sparkSession, logicalPlan, tracker)
+ qe.assertAnalyzed()
+ new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
+ }
}
/**
@@ -166,10 +173,10 @@ private[sql] object Dataset {
*
* @since 1.6.0
*/
-@InterfaceStability.Stable
+@Stable
class Dataset[T] private[sql](
@transient val sparkSession: SparkSession,
- @DeveloperApi @InterfaceStability.Unstable @transient val queryExecution: QueryExecution,
+ @DeveloperApi @Unstable @transient val queryExecution: QueryExecution,
encoder: Encoder[T])
extends Serializable {
@@ -426,7 +433,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan)
/**
@@ -544,7 +551,7 @@ class Dataset[T] private[sql](
* @group streaming
* @since 2.0.0
*/
- @InterfaceStability.Evolving
+ @Evolving
def isStreaming: Boolean = logicalPlan.isStreaming
/**
@@ -557,7 +564,7 @@ class Dataset[T] private[sql](
* @since 2.1.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)
/**
@@ -570,7 +577,7 @@ class Dataset[T] private[sql](
* @since 2.1.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true)
/**
@@ -583,7 +590,7 @@ class Dataset[T] private[sql](
* @since 2.3.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)
/**
@@ -596,7 +603,7 @@ class Dataset[T] private[sql](
* @since 2.3.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint(
eager = eager,
reliableCheckpoint = false
@@ -671,7 +678,7 @@ class Dataset[T] private[sql](
* @group streaming
* @since 2.1.0
*/
- @InterfaceStability.Evolving
+ @Evolving
// We only accept an existing column name, not a derived column here as a watermark that is
// defined on a derived column cannot referenced elsewhere in the plan.
def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan {
@@ -1066,7 +1073,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
// Creates a Join node and resolve it first, to get join condition resolved, self-join resolved,
// etc.
@@ -1086,7 +1093,7 @@ class Dataset[T] private[sql](
// Note that we do this before joining them, to enable the join operator to return null for one
// side, in cases like outer-join.
val left = {
- val combined = if (!this.exprEnc.isSerializedAsStruct) {
+ val combined = if (!this.exprEnc.isSerializedAsStructForTopLevel) {
assert(joined.left.output.length == 1)
Alias(joined.left.output.head, "_1")()
} else {
@@ -1096,7 +1103,7 @@ class Dataset[T] private[sql](
}
val right = {
- val combined = if (!other.exprEnc.isSerializedAsStruct) {
+ val combined = if (!other.exprEnc.isSerializedAsStructForTopLevel) {
assert(joined.right.output.length == 1)
Alias(joined.right.output.head, "_2")()
} else {
@@ -1109,14 +1116,14 @@ class Dataset[T] private[sql](
// combine the outputs of each join side.
val conditionExpr = joined.condition.get transformUp {
case a: Attribute if joined.left.outputSet.contains(a) =>
- if (!this.exprEnc.isSerializedAsStruct) {
+ if (!this.exprEnc.isSerializedAsStructForTopLevel) {
left.output.head
} else {
val index = joined.left.output.indexWhere(_.exprId == a.exprId)
GetStructField(left.output.head, index)
}
case a: Attribute if joined.right.outputSet.contains(a) =>
- if (!other.exprEnc.isSerializedAsStruct) {
+ if (!other.exprEnc.isSerializedAsStructForTopLevel) {
right.output.head
} else {
val index = joined.right.output.indexWhere(_.exprId == a.exprId)
@@ -1142,7 +1149,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
joinWith(other, condition, "inner")
}
@@ -1384,12 +1391,12 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
implicit val encoder = c1.encoder
val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)
- if (!encoder.isSerializedAsStruct) {
+ if (!encoder.isSerializedAsStructForTopLevel) {
new Dataset[U1](sparkSession, project, encoder)
} else {
// Flattens inner fields of U1
@@ -1418,7 +1425,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
@@ -1430,7 +1437,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def select[U1, U2, U3](
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
@@ -1445,7 +1452,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def select[U1, U2, U3, U4](
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
@@ -1461,7 +1468,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def select[U1, U2, U3, U4, U5](
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
@@ -1632,7 +1639,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def reduce(func: (T, T) => T): T = withNewRDDExecutionId {
rdd.reduce(func)
}
@@ -1647,7 +1654,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
/**
@@ -1659,7 +1666,7 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
val withGroupingKey = AppendColumns(func, logicalPlan)
val executed = sparkSession.sessionState.executePlan(withGroupingKey)
@@ -1681,7 +1688,7 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
groupByKey(func.call(_))(encoder)
@@ -2497,7 +2504,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def filter(func: T => Boolean): Dataset[T] = {
withTypedPlan(TypedFilter(func, logicalPlan))
}
@@ -2511,7 +2518,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def filter(func: FilterFunction[T]): Dataset[T] = {
withTypedPlan(TypedFilter(func, logicalPlan))
}
@@ -2525,7 +2532,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
MapElements[T, U](func, logicalPlan)
}
@@ -2539,7 +2546,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
implicit val uEnc = encoder
withTypedPlan(MapElements[T, U](func, logicalPlan))
@@ -2554,7 +2561,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
new Dataset[U](
sparkSession,
@@ -2571,7 +2578,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
mapPartitions(func)(encoder)
@@ -2602,7 +2609,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
mapPartitions(_.flatMap(func))
@@ -2616,7 +2623,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
val func: (T) => Iterator[U] = x => f.call(x).asScala
flatMap(func)(encoder)
@@ -2803,6 +2810,12 @@ class Dataset[T] private[sql](
* When no explicit sort order is specified, "ascending nulls first" is assumed.
* Note, the rows are not sorted in each partition of the resulting Dataset.
*
+ *
+ * Note that due to performance reasons this method uses sampling to estimate the ranges.
+ * Hence, the output may not be consistent, since sampling can return different values.
+ * The sample size can be controlled by the config
+ * `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
+ *
* @group typedrel
* @since 2.3.0
*/
@@ -2827,6 +2840,11 @@ class Dataset[T] private[sql](
* When no explicit sort order is specified, "ascending nulls first" is assumed.
* Note, the rows are not sorted in each partition of the resulting Dataset.
*
+ * Note that due to performance reasons this method uses sampling to estimate the ranges.
+ * Hence, the output may not be consistent, since sampling can return different values.
+ * The sample size can be controlled by the config
+ * `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
+ *
* @group typedrel
* @since 2.3.0
*/
@@ -3078,7 +3096,7 @@ class Dataset[T] private[sql](
* @group basic
* @since 2.0.0
*/
- @InterfaceStability.Evolving
+ @Evolving
def writeStream: DataStreamWriter[T] = {
if (!isStreaming) {
logicalPlan.failAnalysis(
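
The added note on repartitionByRange can be exercised as below; spark, events and the column name are placeholders, and the config key is taken verbatim from the note:

    import spark.implicits._

    // Range partitioning estimates its boundaries by sampling, so two runs can place rows into
    // slightly different partitions; raising the sample size makes the boundaries more stable.
    spark.conf.set("spark.sql.execution.rangeExchange.sampleSizePerPartition", 200)
    val ranged = events.repartitionByRange(8, $"latency")
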
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
index 08aa1bbe78fae..1c4ffefb897ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
/**
* A container for a [[Dataset]], used for implicit conversions in Scala.
@@ -30,7 +30,7 @@ import org.apache.spark.annotation.InterfaceStability
*
* @since 1.6.0
*/
-@InterfaceStability.Stable
+@Stable
case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) {
// This is declared with parentheses to prevent the Scala compiler from treating
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index bd8dd6ea3fe0f..302d38cde1430 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.{Experimental, Unstable}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
* @since 1.3.0
*/
@Experimental
-@InterfaceStability.Unstable
+@Unstable
class ExperimentalMethods private[sql]() {
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
index 52b8c839643e7..5c0fe798b1044 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
/**
* The abstract class for writing custom logic to process data generated by a query.
@@ -104,7 +104,7 @@ import org.apache.spark.annotation.InterfaceStability
*
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
abstract class ForeachWriter[T] extends Serializable {
// TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 555bcdffb6ee4..a3cbea9021f22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.{Evolving, Experimental}
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
/**
@@ -37,7 +38,7 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode
* @since 2.0.0
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
class KeyValueGroupedDataset[K, V] private[sql](
kEncoder: Encoder[K],
vEncoder: Encoder[V],
@@ -237,7 +238,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def mapGroupsWithState[S: Encoder, U: Encoder](
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
@@ -272,7 +273,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def mapGroupsWithState[S: Encoder, U: Encoder](
timeoutConf: GroupStateTimeout)(
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
@@ -309,7 +310,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def mapGroupsWithState[S, U](
func: MapGroupsWithStateFunction[K, V, S, U],
stateEncoder: Encoder[S],
@@ -340,7 +341,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def mapGroupsWithState[S, U](
func: MapGroupsWithStateFunction[K, V, S, U],
stateEncoder: Encoder[S],
@@ -371,7 +372,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def flatMapGroupsWithState[S: Encoder, U: Encoder](
outputMode: OutputMode,
timeoutConf: GroupStateTimeout)(
@@ -413,7 +414,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def flatMapGroupsWithState[S, U](
func: FlatMapGroupsWithStateFunction[K, V, S, U],
outputMode: OutputMode,
@@ -457,9 +458,13 @@ class KeyValueGroupedDataset[K, V] private[sql](
val encoders = columns.map(_.encoder)
val namedColumns =
columns.map(_.withInputType(vExprEnc, dataAttributes).named)
- val keyColumn = if (!kExprEnc.isSerializedAsStruct) {
+ val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) {
assert(groupingAttributes.length == 1)
- groupingAttributes.head
+ if (SQLConf.get.nameNonStructGroupingKeyAsValue) {
+ groupingAttributes.head
+ } else {
+ Alias(groupingAttributes.head, "key")()
+ }
} else {
Alias(CreateStruct(groupingAttributes), "key")()
}
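
A hedged sketch of the key-naming change in the hunk above; the expected column name is an inference from that hunk, not verified here:

    import spark.implicits._

    val pairs = Seq(("a", 1), ("a", 2), ("b", 5)).toDS()
    val counts = pairs.groupByKey(_._1).count()
    counts.printSchema()
    // With the new default, a non-struct grouping key surfaces as a column named "key";
    // only when SQLConf.get.nameNonStructGroupingKeyAsValue (read in the hunk above) is
    // enabled does the old behaviour of exposing the grouping attribute unchanged apply.
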
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index d4e75b5ebd405..e85636d82a62c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import scala.collection.JavaConverters._
import scala.language.implicitConversions
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
@@ -45,7 +45,7 @@ import org.apache.spark.sql.types.{NumericType, StructType}
*
* @since 2.0.0
*/
-@InterfaceStability.Stable
+@Stable
class RelationalGroupedDataset protected[sql](
df: DataFrame,
groupingExprs: Seq[Expression],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
index 3c39579149fff..5a554eff02e3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry}
import org.apache.spark.sql.internal.SQLConf
-
/**
* Runtime configuration interface for Spark. To access this, use `SparkSession.conf`.
*
@@ -29,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf
*
* @since 2.0.0
*/
-@InterfaceStability.Stable
+@Stable
class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) {
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 9982b60fefe60..43f34e6ff4b85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -23,7 +23,7 @@ import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.annotation._
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.ConfigEntry
@@ -54,7 +54,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager
* @groupname Ungrouped Support functions for language integrated queries
* @since 1.0.0
*/
-@InterfaceStability.Stable
+@Stable
class SQLContext private[sql](val sparkSession: SparkSession)
extends Logging with Serializable {
@@ -86,7 +86,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* that listen for execution metrics.
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def listenerManager: ExecutionListenerManager = sparkSession.listenerManager
/**
@@ -158,7 +158,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
*/
@Experimental
@transient
- @InterfaceStability.Unstable
+ @Unstable
def experimental: ExperimentalMethods = sparkSession.experimental
/**
@@ -244,7 +244,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @since 1.3.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
object implicits extends SQLImplicits with Serializable {
protected override def _sqlContext: SQLContext = self
}
@@ -258,7 +258,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @since 1.3.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
sparkSession.createDataFrame(rdd)
}
@@ -271,7 +271,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @since 1.3.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
sparkSession.createDataFrame(data)
}
@@ -319,7 +319,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @since 1.3.0
*/
@DeveloperApi
- @InterfaceStability.Evolving
+ @Evolving
def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
sparkSession.createDataFrame(rowRDD, schema)
}
@@ -363,7 +363,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @group dataset
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
sparkSession.createDataset(data)
}
@@ -401,7 +401,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @group dataset
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
sparkSession.createDataset(data)
}
@@ -428,7 +428,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @since 1.3.0
*/
@DeveloperApi
- @InterfaceStability.Evolving
+ @Evolving
def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
sparkSession.createDataFrame(rowRDD, schema)
}
@@ -443,7 +443,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @since 1.6.0
*/
@DeveloperApi
- @InterfaceStability.Evolving
+ @Evolving
def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
sparkSession.createDataFrame(rows, schema)
}
@@ -507,7 +507,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
*
* @since 2.0.0
*/
- @InterfaceStability.Evolving
+ @Evolving
def readStream: DataStreamReader = sparkSession.readStream
@@ -631,7 +631,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @group dataframe
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def range(end: Long): DataFrame = sparkSession.range(end).toDF()
/**
@@ -643,7 +643,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @group dataframe
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def range(start: Long, end: Long): DataFrame = sparkSession.range(start, end).toDF()
/**
@@ -655,7 +655,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @group dataframe
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def range(start: Long, end: Long, step: Long): DataFrame = {
sparkSession.range(start, end, step).toDF()
}
@@ -670,7 +670,7 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* @group dataframe
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = {
sparkSession.range(start, end, step, numPartitions).toDF()
}
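
The SQLContext hunks above (like the RelationalGroupedDataset and RuntimeConfig changes before them) only swap the nested InterfaceStability.Stable/Evolving/Unstable annotations for the new top-level Stable, Evolving and Unstable ones. A minimal, self-contained Scala sketch of the usage pattern; the annotation classes below are simplified stand-ins, not Spark's actual definitions (which carry retention and documentation metadata):

    import scala.annotation.StaticAnnotation

    // Simplified stand-ins for org.apache.spark.annotation.{Stable, Evolving, Unstable}.
    class Stable extends StaticAnnotation
    class Evolving extends StaticAnnotation
    class Unstable extends StaticAnnotation

    // Stable: expected to stay compatible across minor releases.
    @Stable
    class RuntimeConfigLike {
      private var settings = Map.empty[String, String]
      def set(key: String, value: String): Unit = settings += key -> value
      def get(key: String): Option[String] = settings.get(key)
    }

    // Evolving: may still change between minor releases.
    @Evolving
    trait StreamReaderLike {
      def format(source: String): StreamReaderLike
    }
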
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 05db292bd41b1..d329af0145c2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -21,7 +21,7 @@ import scala.collection.Map
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
*
* @since 1.6.0
*/
-@InterfaceStability.Evolving
+@Evolving
abstract class SQLImplicits extends LowPrioritySQLImplicits {
protected def _sqlContext: SQLContext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 3f0b8208612d7..703272acda22c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -25,7 +25,7 @@ import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
-import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
@@ -73,7 +73,7 @@ import org.apache.spark.util.{CallSite, Utils}
* @param parentSessionState If supplied, inherit all session state (i.e. temporary
* views, SQL config, UDFs etc) from parent.
*/
-@InterfaceStability.Stable
+@Stable
class SparkSession private(
@transient val sparkContext: SparkContext,
@transient private val existingSharedState: Option[SharedState],
@@ -124,7 +124,7 @@ class SparkSession private(
*
* @since 2.2.0
*/
- @InterfaceStability.Unstable
+ @Unstable
@transient
lazy val sharedState: SharedState = {
existingSharedState.getOrElse(new SharedState(sparkContext))
@@ -145,7 +145,7 @@ class SparkSession private(
*
* @since 2.2.0
*/
- @InterfaceStability.Unstable
+ @Unstable
@transient
lazy val sessionState: SessionState = {
parentSessionState
@@ -186,7 +186,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def listenerManager: ExecutionListenerManager = sessionState.listenerManager
/**
@@ -197,7 +197,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Unstable
+ @Unstable
def experimental: ExperimentalMethods = sessionState.experimentalMethods
/**
@@ -231,7 +231,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Unstable
+ @Unstable
def streams: StreamingQueryManager = sessionState.streamingQueryManager
/**
@@ -289,7 +289,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def emptyDataset[T: Encoder]: Dataset[T] = {
val encoder = implicitly[Encoder[T]]
new Dataset(self, LocalRelation(encoder.schema.toAttributes), encoder)
@@ -302,7 +302,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
SparkSession.setActiveSession(this)
val encoder = Encoders.product[A]
@@ -316,7 +316,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
SparkSession.setActiveSession(this)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
@@ -356,7 +356,7 @@ class SparkSession private(
* @since 2.0.0
*/
@DeveloperApi
- @InterfaceStability.Evolving
+ @Evolving
def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
createDataFrame(rowRDD, schema, needsConversion = true)
}
@@ -370,7 +370,7 @@ class SparkSession private(
* @since 2.0.0
*/
@DeveloperApi
- @InterfaceStability.Evolving
+ @Evolving
def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
createDataFrame(rowRDD.rdd, schema)
}
@@ -384,7 +384,7 @@ class SparkSession private(
* @since 2.0.0
*/
@DeveloperApi
- @InterfaceStability.Evolving
+ @Evolving
def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala))
}
@@ -474,7 +474,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
val enc = encoderFor[T]
val attributes = enc.schema.toAttributes
@@ -493,7 +493,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = {
Dataset[T](self, ExternalRDD(data, self))
}
@@ -515,7 +515,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
createDataset(data.asScala)
}
@@ -528,7 +528,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def range(end: Long): Dataset[java.lang.Long] = range(0, end)
/**
@@ -539,7 +539,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def range(start: Long, end: Long): Dataset[java.lang.Long] = {
range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism)
}
@@ -552,7 +552,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
range(start, end, step, numPartitions = sparkContext.defaultParallelism)
}
@@ -566,7 +566,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
new Dataset(self, Range(start, end, step, numPartitions), Encoders.LONG)
}
@@ -648,7 +648,11 @@ class SparkSession private(
* @since 2.0.0
*/
def sql(sqlText: String): DataFrame = {
- Dataset.ofRows(self, sessionState.sqlParser.parsePlan(sqlText))
+ val tracker = new QueryPlanningTracker
+ val plan = tracker.measureTime(QueryPlanningTracker.PARSING) {
+ sessionState.sqlParser.parsePlan(sqlText)
+ }
+ Dataset.ofRows(self, plan, tracker)
}
/**
@@ -672,7 +676,7 @@ class SparkSession private(
*
* @since 2.0.0
*/
- @InterfaceStability.Evolving
+ @Evolving
def readStream: DataStreamReader = new DataStreamReader(self)
/**
@@ -706,7 +710,7 @@ class SparkSession private(
* @since 2.0.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
object implicits extends SQLImplicits with Serializable {
protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext
}
@@ -775,13 +779,13 @@ class SparkSession private(
}
-@InterfaceStability.Stable
+@Stable
object SparkSession extends Logging {
/**
* Builder for [[SparkSession]].
*/
- @InterfaceStability.Stable
+ @Stable
class Builder extends Logging {
private[this] val options = new scala.collection.mutable.HashMap[String, String]
@@ -1146,7 +1150,7 @@ object SparkSession extends Logging {
val extensionConfClassName = extensionOption.get
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
- val extensionConf = extensionConfClass.newInstance()
+ val extensionConf = extensionConfClass.getConstructor().newInstance()
.asInstanceOf[SparkSessionExtensions => Unit]
extensionConf(extensions)
} catch {
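
SparkSession.sql now times parsing through a QueryPlanningTracker (QueryPlanningTracker.PARSING), and QueryExecution further below records the analysis, optimization and planning phases against the same tracker. A rough, self-contained sketch of the timing pattern; PhaseTracker and the phase names are illustrative stand-ins, not the real QueryPlanningTracker API:

    import scala.collection.mutable

    // Hypothetical stand-in for QueryPlanningTracker: records wall-clock time per phase.
    class PhaseTracker {
      private val phases = mutable.Map.empty[String, Long]

      def measureTime[T](phase: String)(block: => T): T = {
        val start = System.nanoTime()
        try block
        finally {
          val elapsed = System.nanoTime() - start
          phases(phase) = phases.getOrElse(phase, 0L) + elapsed
        }
      }

      def report(): Map[String, Long] = phases.toMap
    }

    object PhaseTrackerExample {
      def main(args: Array[String]): Unit = {
        val tracker = new PhaseTracker
        val parsed = tracker.measureTime("parsing") { "SELECT 1".trim }     // stand-in for parsePlan
        tracker.measureTime("analysis") { parsed.toUpperCase }              // stand-in for analysis
        println(tracker.report())                                           // nanoseconds per phase
      }
    }
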
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index a4864344b2d25..5ed76789786bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import scala.collection.mutable
-import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Unstable}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
@@ -66,7 +66,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
*/
@DeveloperApi
@Experimental
-@InterfaceStability.Unstable
+@Unstable
class SparkSessionExtensions {
type RuleBuilder = SparkSession => Rule[LogicalPlan]
type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index aa3a6c3bf122f..5a3f556c9c074 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -22,7 +22,7 @@ import java.lang.reflect.ParameterizedType
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.internal.Logging
import org.apache.spark.sql.api.java._
@@ -44,7 +44,7 @@ import org.apache.spark.util.Utils
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging {
protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
@@ -670,7 +670,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class $className")
} else {
try {
- val udf = clazz.newInstance()
+ val udf = clazz.getConstructor().newInstance()
val udfReturnType = udfInterfaces(0).getActualTypeArguments.last
var returnType = returnDataType
if (returnType == null) {
@@ -727,7 +727,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction")
}
- val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction]
+ val udaf = clazz.getConstructor().newInstance().asInstanceOf[UserDefinedAggregateFunction]
register(name, udaf)
} catch {
case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath")
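
UDFRegistration (and DataSource.scala further below) replaces clazz.newInstance() with clazz.getConstructor().newInstance(). Class.newInstance has been deprecated since Java 9 because it propagates any exception thrown by the no-arg constructor, checked or not; going through getConstructor() surfaces a missing public no-arg constructor as NoSuchMethodException and wraps constructor failures in InvocationTargetException. A small standalone illustration of the pattern; the instantiate helper and its error handling are this sketch's own, not Spark code:

    import java.lang.reflect.InvocationTargetException

    object ReflectiveInstantiation {
      def instantiate[T](clazz: Class[T]): T =
        try {
          // Equivalent of the new call sites: look up the public no-arg constructor first.
          clazz.getConstructor().newInstance()
        } catch {
          case e: NoSuchMethodException =>
            throw new IllegalArgumentException(s"${clazz.getName} has no public no-arg constructor", e)
          case e: InvocationTargetException =>
            // The constructor itself threw; the real cause is preserved here.
            throw new IllegalStateException(s"Failed to construct ${clazz.getName}", e.getCause)
        }

      def main(args: Array[String]): Unit = {
        val sb = instantiate(classOf[java.lang.StringBuilder])
        println(sb.append("ok").toString)
      }
    }
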
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index af20764f9a968..becb05cf72aba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -111,7 +111,7 @@ private[sql] object SQLUtils extends Logging {
private[this] def doConversion(data: Object, dataType: DataType): Object = {
data match {
case d: java.lang.Double if dataType == FloatType =>
- new java.lang.Float(d)
+ java.lang.Float.valueOf(d.toFloat)
// Scala Map is the only allowed external type of map type in Row.
case m: java.util.Map[_, _] => m.asScala
case _ => data
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index ab81725def3f4..44668610d8052 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalog
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.{Evolving, Experimental, Stable}
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.storage.StorageLevel
@@ -29,7 +29,7 @@ import org.apache.spark.storage.StorageLevel
*
* @since 2.0.0
*/
-@InterfaceStability.Stable
+@Stable
abstract class Catalog {
/**
@@ -233,7 +233,7 @@ abstract class Catalog {
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createTable(tableName: String, path: String): DataFrame
/**
@@ -261,7 +261,7 @@ abstract class Catalog {
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createTable(tableName: String, path: String, source: String): DataFrame
/**
@@ -292,7 +292,7 @@ abstract class Catalog {
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createTable(
tableName: String,
source: String,
@@ -330,7 +330,7 @@ abstract class Catalog {
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createTable(
tableName: String,
source: String,
@@ -366,7 +366,7 @@ abstract class Catalog {
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createTable(
tableName: String,
source: String,
@@ -406,7 +406,7 @@ abstract class Catalog {
* @since 2.2.0
*/
@Experimental
- @InterfaceStability.Evolving
+ @Evolving
def createTable(
tableName: String,
source: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
index c0c5ebc2ba2d6..cb270875228ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalog
import javax.annotation.Nullable
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.DefinedByConstructorParams
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams
* @param locationUri path (in the form of a uri) to data files.
* @since 2.0.0
*/
-@InterfaceStability.Stable
+@Stable
class Database(
val name: String,
@Nullable val description: String,
@@ -61,7 +61,7 @@ class Database(
* @param isTemporary whether the table is a temporary table.
* @since 2.0.0
*/
-@InterfaceStability.Stable
+@Stable
class Table(
val name: String,
@Nullable val database: String,
@@ -93,7 +93,7 @@ class Table(
* @param isBucket whether the column is a bucket column.
* @since 2.0.0
*/
-@InterfaceStability.Stable
+@Stable
class Column(
val name: String,
@Nullable val description: String,
@@ -126,7 +126,7 @@ class Column(
* @param isTemporary whether the function is a temporary function or not.
* @since 2.0.0
*/
-@InterfaceStability.Stable
+@Stable
class Function(
val name: String,
@Nullable val database: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 9b500c1a040ae..9c04f8679b46d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -56,8 +57,8 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
case (key, value) =>
key + ": " + StringUtils.abbreviate(redact(value), 100)
}
- val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "")
- s"$nodeNamePrefix$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr"
+ val metadataStr = truncatedString(metadataEntries, " ", ", ", "")
+ s"$nodeNamePrefix$nodeName${truncatedString(output, "[", ",", "]")}$metadataStr"
}
override def verboseString: String = redact(super.verboseString)
@@ -83,7 +84,7 @@ case class RowDataSourceScanExec(
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
override val tableIdentifier: Option[TableIdentifier])
- extends DataSourceScanExec {
+ extends DataSourceScanExec with InputRDDCodegen {
def output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput)
@@ -103,30 +104,10 @@ case class RowDataSourceScanExec(
}
}
- override def inputRDDs(): Seq[RDD[InternalRow]] = {
- rdd :: Nil
- }
+ // Input can be InternalRows; they have to be turned into UnsafeRows.
+ override protected val createUnsafeProjection: Boolean = true
- override protected def doProduce(ctx: CodegenContext): String = {
- val numOutputRows = metricTerm(ctx, "numOutputRows")
- // PhysicalRDD always just has one input
- val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
- val exprRows = output.zipWithIndex.map{ case (a, i) =>
- BoundReference(i, a.dataType, a.nullable)
- }
- val row = ctx.freshName("row")
- ctx.INPUT_ROW = row
- ctx.currentVars = null
- val columnsRowInput = exprRows.map(_.genCode(ctx))
- s"""
- |while ($input.hasNext()) {
- | InternalRow $row = (InternalRow) $input.next();
- | $numOutputRows.add(1);
- | ${consume(ctx, columnsRowInput).trim}
- | if (shouldStop()) return;
- |}
- """.stripMargin
- }
+ override def inputRDD: RDD[InternalRow] = rdd
override val metadata: Map[String, String] = {
val markedFilters = for (filter <- filters) yield {
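
DataSourceScanExec, and a number of operators below, switch from Utils.truncatedString to the relocated org.apache.spark.sql.catalyst.util.truncatedString when rendering output and metadata lists. A simplified standalone re-implementation of what such a helper does; the maxFields default and the "more fields" wording are assumptions, not Spark's exact behaviour:

    object TruncatedString {
      /** Joins up to maxFields elements, appending a marker when the sequence is longer. */
      def truncatedString[T](seq: Seq[T], start: String, sep: String, end: String,
          maxFields: Int = 25): String = {
        if (seq.length > maxFields) {
          val shown = seq.take(maxFields - 1).mkString(start, sep, "")
          s"$shown$sep... ${seq.length - maxFields + 1} more fields$end"
        } else {
          seq.mkString(start, sep, end)
        }
      }

      def main(args: Array[String]): Unit = {
        val cols = (1 to 30).map(i => s"c$i")
        // Mirrors call sites like truncatedString(output, "[", ",", "]").
        println(truncatedString(cols, "[", ",", "]", maxFields = 5))
      }
    }
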
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 2962becb64e88..e214bfd050410 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -24,9 +24,9 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.Utils
object RDDConversions {
def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
@@ -175,7 +175,7 @@ case class RDDScanExec(
rdd: RDD[InternalRow],
name: String,
override val outputPartitioning: Partitioning = UnknownPartitioning(0),
- override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode {
+ override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode with InputRDDCodegen {
private def rddName: String = Option(rdd.name).map(n => s" $n").getOrElse("")
@@ -197,6 +197,11 @@ case class RDDScanExec(
}
override def simpleString: String = {
- s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}"
+ s"$nodeName${truncatedString(output, "[", ",", "]")}"
}
+
+ // Input can be InternalRows; they have to be turned into UnsafeRows.
+ override protected val createUnsafeProjection: Boolean = true
+
+ override def inputRDD: RDD[InternalRow] = rdd
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index 448eb703eacde..31640db3722ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
*/
case class LocalTableScanExec(
output: Seq[Attribute],
- @transient rows: Seq[InternalRow]) extends LeafExecNode {
+ @transient rows: Seq[InternalRow]) extends LeafExecNode with InputRDDCodegen {
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -76,4 +76,12 @@ case class LocalTableScanExec(
longMetric("numOutputRows").add(taken.size)
taken
}
+
+ // Input is already UnsafeRows.
+ override protected val createUnsafeProjection: Boolean = false
+
+ // Do not codegen when there is no parent - to support the fast driver-local collect/take paths.
+ override def supportCodegen: Boolean = (parent != null)
+
+ override def inputRDD: RDD[InternalRow] = rdd
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index c957285b2a315..714581ee10c32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -17,16 +17,21 @@
package org.apache.spark.sql.execution
+import java.io.{BufferedWriter, OutputStreamWriter, Writer}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
+import org.apache.commons.io.output.StringBuilderWriter
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.adaptive.PlanQueryStage
import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
@@ -40,7 +45,10 @@ import org.apache.spark.util.Utils
* While this is not a public class, we should avoid changing the function names for the sake of
* changing them, because a lot of developers use the feature for debugging.
*/
-class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
+class QueryExecution(
+ val sparkSession: SparkSession,
+ val logical: LogicalPlan,
+ val tracker: QueryPlanningTracker = new QueryPlanningTracker) {
// TODO: Move the planner an optimizer into here from SessionState.
protected def planner = sparkSession.sessionState.planner
@@ -53,9 +61,9 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
}
}
- lazy val analyzed: LogicalPlan = {
+ lazy val analyzed: LogicalPlan = tracker.measureTime(QueryPlanningTracker.ANALYSIS) {
SparkSession.setActiveSession(sparkSession)
- sparkSession.sessionState.analyzer.executeAndCheck(logical)
+ sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker)
}
lazy val withCachedData: LogicalPlan = {
@@ -64,9 +72,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
sparkSession.sharedState.cacheManager.useCachedData(analyzed)
}
- lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData)
+ lazy val optimizedPlan: LogicalPlan = tracker.measureTime(QueryPlanningTracker.OPTIMIZATION) {
+ sparkSession.sessionState.optimizer.executeAndTrack(withCachedData, tracker)
+ }
- lazy val sparkPlan: SparkPlan = {
+ lazy val sparkPlan: SparkPlan = tracker.measureTime(QueryPlanningTracker.PLANNING) {
SparkSession.setActiveSession(sparkSession)
// TODO: We use next(), i.e. take the first plan returned by the planner, here for now,
// but we will implement to choose the best plan.
@@ -75,7 +85,9 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
- lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
+ lazy val executedPlan: SparkPlan = tracker.measureTime(QueryPlanningTracker.PLANNING) {
+ prepareForExecution(sparkPlan)
+ }
/** Internal version of the RDD. Avoids copies and has no schema */
lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
@@ -203,23 +215,38 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
""".stripMargin.trim
}
+ private def writeOrError(writer: Writer)(f: Writer => Unit): Unit = {
+ try f(writer)
+ catch {
+ case e: AnalysisException => writer.write(e.toString)
+ }
+ }
+
+ private def writePlans(writer: Writer): Unit = {
+ val (verbose, addSuffix) = (true, false)
+
+ writer.write("== Parsed Logical Plan ==\n")
+ writeOrError(writer)(logical.treeString(_, verbose, addSuffix))
+ writer.write("\n== Analyzed Logical Plan ==\n")
+ val analyzedOutput = stringOrError(truncatedString(
+ analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", "))
+ writer.write(analyzedOutput)
+ writer.write("\n")
+ writeOrError(writer)(analyzed.treeString(_, verbose, addSuffix))
+ writer.write("\n== Optimized Logical Plan ==\n")
+ writeOrError(writer)(optimizedPlan.treeString(_, verbose, addSuffix))
+ writer.write("\n== Physical Plan ==\n")
+ writeOrError(writer)(executedPlan.treeString(_, verbose, addSuffix))
+ }
+
override def toString: String = withRedaction {
- def output = Utils.truncatedString(
- analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ")
- val analyzedPlan = Seq(
- stringOrError(output),
- stringOrError(analyzed.treeString(verbose = true))
- ).filter(_.nonEmpty).mkString("\n")
-
- s"""== Parsed Logical Plan ==
- |${stringOrError(logical.treeString(verbose = true))}
- |== Analyzed Logical Plan ==
- |$analyzedPlan
- |== Optimized Logical Plan ==
- |${stringOrError(optimizedPlan.treeString(verbose = true))}
- |== Physical Plan ==
- |${stringOrError(executedPlan.treeString(verbose = true))}
- """.stripMargin.trim
+ val writer = new StringBuilderWriter()
+ try {
+ writePlans(writer)
+ writer.toString
+ } finally {
+ writer.close()
+ }
}
def stringWithStats: String = withRedaction {
@@ -264,5 +291,22 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
def codegenToSeq(): Seq[(String, String)] = {
org.apache.spark.sql.execution.debug.codegenStringSeq(executedPlan)
}
+
+ /**
+ * Dumps debug information about query execution into the specified file.
+ */
+ def toFile(path: String): Unit = {
+ val filePath = new Path(path)
+ val fs = filePath.getFileSystem(sparkSession.sessionState.newHadoopConf())
+ val writer = new BufferedWriter(new OutputStreamWriter(fs.create(filePath)))
+
+ try {
+ writePlans(writer)
+ writer.write("\n== Whole Stage Codegen ==\n")
+ org.apache.spark.sql.execution.debug.writeCodegen(writer, executedPlan)
+ } finally {
+ writer.close()
+ }
+ }
}
}
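
QueryExecution.toString now renders each plan section through a Writer via writePlans, with writeOrError catching per-section failures, and the same helper backs the new toFile debug dump. A standalone sketch of that pattern using plain java.io; the section contents and the StringWriter stand-in (for StringBuilderWriter or an HDFS output stream) are illustrative:

    import java.io.{StringWriter, Writer}

    object PlanWriterSketch {
      // Mirrors writeOrError: a failure while rendering one section must not lose the others.
      private def writeOrError(writer: Writer)(f: Writer => Unit): Unit =
        try f(writer) catch { case e: Exception => writer.write(e.toString) }

      private def writePlans(writer: Writer, parsed: String, analyzed: String,
          optimized: String, physical: String): Unit = {
        writer.write("== Parsed Logical Plan ==\n")
        writeOrError(writer)(_.write(parsed))
        writer.write("\n== Analyzed Logical Plan ==\n")
        writeOrError(writer)(_.write(analyzed))
        writer.write("\n== Optimized Logical Plan ==\n")
        writeOrError(writer)(_.write(optimized))
        writer.write("\n== Physical Plan ==\n")
        writeOrError(writer)(_.write(physical))
      }

      def main(args: Array[String]): Unit = {
        val writer = new StringWriter()   // stand-in for StringBuilderWriter / a file stream
        try {
          writePlans(writer, "Project [a]", "a: int\nProject [a#1]",
            "Project [a#1]", "*(1) Project [a#1]")
          println(writer.toString)
        } finally writer.close()
      }
    }
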
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index 862ee05392f37..9b05faaed0459 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -22,6 +22,7 @@ import java.util.Arrays
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleMetricsReporter}
/**
* The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition
@@ -112,6 +113,7 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A
*/
class ShuffledRowRDD(
var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
+ metrics: Map[String, SQLMetric],
specifiedPartitionStartIndices: Option[Array[Int]] = None)
extends RDD[InternalRow](dependency.rdd.context, Nil) {
@@ -154,6 +156,10 @@ class ShuffledRowRDD(
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition]
+ val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
+ // `SQLShuffleMetricsReporter` will update its own metrics for SQL exchange operator,
+ // as well as the `tempMetrics` for basic shuffle metrics.
+ val sqlMetricsReporter = new SQLShuffleMetricsReporter(tempMetrics, metrics)
// The range of pre-shuffle partitions that we are fetching at here is
// [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1].
val reader =
@@ -161,7 +167,8 @@ class ShuffledRowRDD(
dependency.shuffleHandle,
shuffledRowPartition.startPreShufflePartitionIndex,
shuffledRowPartition.endPreShufflePartitionIndex,
- context)
+ context,
+ sqlMetricsReporter)
reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
}
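
ShuffledRowRDD now receives the operator's SQL metrics and, in compute, wraps the task's temporary shuffle-read metrics in a SQLShuffleMetricsReporter so a single read updates both sinks. A self-contained sketch of that "report to two sinks" idea; the class and method names below are invented for illustration, not Spark's metric API:

    import java.util.concurrent.atomic.AtomicLong

    // Stand-ins for a task-level metrics sink and a per-operator SQL metric.
    final class TempShuffleReadMetrics { val recordsRead = new AtomicLong() }
    final class SQLMetricStub(val name: String) { val value = new AtomicLong() }

    // Invented wrapper mirroring the idea of SQLShuffleMetricsReporter:
    // every update is forwarded to the task metrics and to the SQL metric.
    class DualMetricsReporter(temp: TempShuffleReadMetrics, sql: Map[String, SQLMetricStub]) {
      def incRecordsRead(n: Long): Unit = {
        temp.recordsRead.addAndGet(n)
        sql.get("recordsRead").foreach(_.value.addAndGet(n))
      }
    }

    object DualMetricsReporterExample {
      def main(args: Array[String]): Unit = {
        val temp = new TempShuffleReadMetrics
        val sqlMetrics = Map("recordsRead" -> new SQLMetricStub("records read"))
        val reporter = new DualMetricsReporter(temp, sqlMetrics)
        Iterator.fill(3)(1L).foreach(reporter.incRecordsRead)   // simulate reading 3 records
        println(s"task=${temp.recordsRead.get}, sql=${sqlMetrics("recordsRead").value.get}")
      }
    }
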
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 5f81b6fe743c9..fbda0d87a175f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import java.io.Writer
import java.util.Locale
import java.util.function.Supplier
@@ -349,6 +350,15 @@ trait CodegenSupport extends SparkPlan {
*/
def needStopCheck: Boolean = parent.needStopCheck
+ /**
+ * Helper default should stop check code.
+ */
+ def shouldStopCheckCode: String = if (needStopCheck) {
+ "if (shouldStop()) return;"
+ } else {
+ "// shouldStop check is eliminated"
+ }
+
/**
* A sequence of checks which evaluate to true if the downstream Limit operators have not received
* enough records and reached the limit. If current node is a data producing node, it can leverage
@@ -405,6 +415,53 @@ trait BlockingOperatorWithCodegen extends CodegenSupport {
override def limitNotReachedChecks: Seq[String] = Nil
}
+/**
+ * Leaf codegen node reading from a single RDD.
+ */
+trait InputRDDCodegen extends CodegenSupport {
+
+ def inputRDD: RDD[InternalRow]
+
+ // If the input can be InternalRows, an UnsafeProjection needs to be created.
+ protected val createUnsafeProjection: Boolean
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ inputRDD :: Nil
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
+ // Inline mutable state since an InputRDDCodegen is used once in a task for WholeStageCodegen
+ val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
+ forceInline = true)
+ val row = ctx.freshName("row")
+
+ val outputVars = if (createUnsafeProjection) {
+ // Creating the vars will make the parent's consume() add an unsafe projection.
+ ctx.INPUT_ROW = row
+ ctx.currentVars = null
+ output.zipWithIndex.map { case (a, i) =>
+ BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ }
+ } else {
+ null
+ }
+
+ val updateNumOutputRowsMetrics = if (metrics.contains("numOutputRows")) {
+ val numOutputRows = metricTerm(ctx, "numOutputRows")
+ s"$numOutputRows.add(1);"
+ } else {
+ ""
+ }
+ s"""
+ | while ($limitNotReachedCond $input.hasNext()) {
+ | InternalRow $row = (InternalRow) $input.next();
+ | ${updateNumOutputRowsMetrics}
+ | ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim}
+ | ${shouldStopCheckCode}
+ | }
+ """.stripMargin
+ }
+}
/**
* InputAdapter is used to hide a SparkPlan from a subtree that supports codegen.
@@ -412,7 +469,7 @@ trait BlockingOperatorWithCodegen extends CodegenSupport {
* This is the leaf node of a tree with WholeStageCodegen that is used to generate code
* that consumes an RDD iterator of InternalRow.
*/
-case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+case class InputAdapter(child: SparkPlan) extends UnaryExecNode with InputRDDCodegen {
override def output: Seq[Attribute] = child.output
@@ -428,33 +485,19 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
child.doExecuteBroadcast()
}
- override def inputRDDs(): Seq[RDD[InternalRow]] = {
- child.execute() :: Nil
- }
+ override def inputRDD: RDD[InternalRow] = child.execute()
- override def doProduce(ctx: CodegenContext): String = {
- // Right now, InputAdapter is only used when there is one input RDD.
- // Inline mutable state since an InputAdapter is used once in a task for WholeStageCodegen
- val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
- forceInline = true)
- val row = ctx.freshName("row")
- s"""
- | while ($limitNotReachedCond $input.hasNext()) {
- | InternalRow $row = (InternalRow) $input.next();
- | ${consume(ctx, null, row).trim}
- | if (shouldStop()) return;
- | }
- """.stripMargin
- }
+ // InputAdapter does not need UnsafeProjection.
+ protected val createUnsafeProjection: Boolean = false
override def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
- builder: StringBuilder,
+ writer: Writer,
verbose: Boolean,
prefix: String = "",
- addSuffix: Boolean = false): StringBuilder = {
- child.generateTreeString(depth, lastChildren, builder, verbose, "")
+ addSuffix: Boolean = false): Unit = {
+ child.generateTreeString(depth, lastChildren, writer, verbose, prefix = "", addSuffix = false)
}
override def needCopyResult: Boolean = false
@@ -726,11 +769,11 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
override def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
- builder: StringBuilder,
+ writer: Writer,
verbose: Boolean,
prefix: String = "",
- addSuffix: Boolean = false): StringBuilder = {
- child.generateTreeString(depth, lastChildren, builder, verbose, s"*($codegenStageId) ")
+ addSuffix: Boolean = false): Unit = {
+ child.generateTreeString(depth, lastChildren, writer, verbose, s"*($codegenStageId) ", false)
}
override def needStopCheck: Boolean = true
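
The new InputRDDCodegen trait factors the per-row produce loop out of RowDataSourceScanExec, RDDScanExec, LocalTableScanExec and InputAdapter; per operator only the createUnsafeProjection flag and the optional numOutputRows metric differ, and shouldStopCheckCode lets the stop check be elided. A standalone sketch of the kind of Java source such a template assembles; the variable names and consume code are placeholders, not real codegen output:

    object InputLoopTemplate {
      /** Builds the per-row consume loop, optionally counting output rows. */
      def build(inputVar: String, rowVar: String, consumeCode: String,
          countRows: Boolean, needStopCheck: Boolean): String = {
        val metricUpdate = if (countRows) "numOutputRows.add(1);" else ""
        val stopCheck =
          if (needStopCheck) "if (shouldStop()) return;" else "// shouldStop check is eliminated"
        s"""
           |while ($inputVar.hasNext()) {
           |  InternalRow $rowVar = (InternalRow) $inputVar.next();
           |  $metricUpdate
           |  $consumeCode
           |  $stopCheck
           |}
         """.stripMargin
      }

      def main(args: Array[String]): Unit =
        // Roughly what a scan with a numOutputRows metric would emit.
        println(build("input", "row", "append(row);", countRows = true, needStopCheck = true))
    }
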
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala
index 20f13c280c12c..33d001629f986 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.adaptive
+import java.io.Writer
import java.util.Properties
import scala.concurrent.{ExecutionContext, Future}
@@ -172,11 +173,11 @@ abstract class QueryStage extends UnaryExecNode {
override def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
- builder: StringBuilder,
+ writer: Writer,
verbose: Boolean,
prefix: String = "",
- addSuffix: Boolean = false): StringBuilder = {
- child.generateTreeString(depth, lastChildren, builder, verbose, "*")
+ addSuffix: Boolean = false): Unit = {
+ child.generateTreeString(depth, lastChildren, writer, verbose, "*")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala
index 887f815b6117a..3a43b1cfe94bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql.execution.adaptive
+import java.io.Writer
+
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* QueryStageInput is the leaf node of a QueryStage and is used to hide its child stage. It gets
@@ -58,11 +61,11 @@ abstract class QueryStageInput extends LeafExecNode {
override def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
- builder: StringBuilder,
+ writer: Writer,
verbose: Boolean,
prefix: String = "",
- addSuffix: Boolean = false): StringBuilder = {
- childStage.generateTreeString(depth, lastChildren, builder, verbose, "*")
+ addSuffix: Boolean = false): Unit = {
+ childStage.generateTreeString(depth, lastChildren, writer, verbose, "*")
}
}
@@ -78,13 +81,15 @@ case class ShuffleQueryStageInput(
partitionStartIndices: Option[Array[Int]] = None)
extends QueryStageInput {
+ override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext)
+
override def outputPartitioning: Partitioning = partitionStartIndices.map {
indices => UnknownPartitioning(indices.length)
}.getOrElse(super.outputPartitioning)
override def doExecute(): RDD[InternalRow] = {
val childRDD = childStage.execute().asInstanceOf[ShuffledRowRDD]
- new ShuffledRowRDD(childRDD.dependency, partitionStartIndices)
+ new ShuffledRowRDD(childRDD.dependency, metrics, partitionStartIndices)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 25d8e7dff3d99..4827f838fc514 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.aggregate
import org.apache.spark.TaskContext
-import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
@@ -762,6 +763,8 @@ case class HashAggregateExec(
("true", "true", "", "")
}
+ val oomeClassName = classOf[SparkOutOfMemoryError].getName
+
val findOrInsertRegularHashMap: String =
s"""
|// generate grouping key
@@ -787,7 +790,7 @@ case class HashAggregateExec(
| $unsafeRowKeys, ${hashEval.value});
| if ($unsafeRowBuffer == null) {
| // failed to allocate the first page
- | throw new OutOfMemoryError("No enough memory for aggregation");
+ | throw new $oomeClassName("No enough memory for aggregation");
| }
|}
""".stripMargin
@@ -928,9 +931,9 @@ case class HashAggregateExec(
testFallbackStartsAt match {
case None =>
- val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]")
- val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]")
- val outputString = Utils.truncatedString(output, "[", ", ", "]")
+ val keyString = truncatedString(groupingExpressions, "[", ", ", "]")
+ val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]")
+ val outputString = truncatedString(output, "[", ", ", "]")
if (verbose) {
s"HashAggregate(keys=$keyString, functions=$functionString, output=$outputString)"
} else {
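
The generated fallback path in HashAggregateExec now throws SparkOutOfMemoryError instead of a plain OutOfMemoryError; the class name is captured once with classOf[...].getName and spliced into the code template. A minimal sketch of that splicing pattern, using java.lang.OutOfMemoryError as a stand-in class and an invented message:

    object GeneratedThrowSketch {
      def failFastAllocation(oomeClassName: String, bufferVar: String): String =
        s"""
           |if ($bufferVar == null) {
           |  // failed to allocate the first page
           |  throw new $oomeClassName("Unable to acquire memory for aggregation");
           |}
         """.stripMargin

      def main(args: Array[String]): Unit = {
        // In the real operator this would be classOf[SparkOutOfMemoryError].getName.
        val oomeClassName = classOf[OutOfMemoryError].getName
        println(failFastAllocation(oomeClassName, "unsafeRowBuffer"))
      }
    }
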
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index 66955b8ef723c..7145bb03028d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.Utils
/**
* A hash-based aggregate operator that supports [[TypedImperativeAggregate]] functions that may
@@ -143,9 +143,9 @@ case class ObjectHashAggregateExec(
private def toString(verbose: Boolean): String = {
val allAggregateExpressions = aggregateExpressions
- val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]")
- val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]")
- val outputString = Utils.truncatedString(output, "[", ", ", "]")
+ val keyString = truncatedString(groupingExpressions, "[", ", ", "]")
+ val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]")
+ val outputString = truncatedString(output, "[", ", ", "]")
if (verbose) {
s"ObjectHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)"
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index fc87de2c52e41..d732b905dcdd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.Utils
/**
* Sort-based aggregate operator.
@@ -114,9 +114,9 @@ case class SortAggregateExec(
private def toString(verbose: Boolean): String = {
val allAggregateExpressions = aggregateExpressions
- val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]")
- val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]")
- val outputString = Utils.truncatedString(output, "[", ", ", "]")
+ val keyString = truncatedString(groupingExpressions, "[", ", ", "]")
+ val functionString = truncatedString(allAggregateExpressions, "[", ", ", "]")
+ val outputString = truncatedString(output, "[", ", ", "]")
if (verbose) {
s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)"
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 72505f7fac0c6..6d849869b577a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -206,7 +206,9 @@ class TungstenAggregationIterator(
buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
if (buffer == null) {
// failed to allocate the first page
+ // scalastyle:off throwerror
throw new SparkOutOfMemoryError("No enough memory for aggregation")
+ // scalastyle:on throwerror
}
}
processRow(buffer, newInput)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 39200ec00e152..b75752945a492 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -40,9 +40,9 @@ object TypedAggregateExpression {
val outputEncoder = encoderFor[OUT]
val outputType = outputEncoder.objSerializer.dataType
- // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer
- // expression is an alias of `BoundReference`, which means the buffer object doesn't need
- // serialization.
+ // Checks if the buffer object is simple, i.e. the `BUF` type is not serialized as struct
+ // and the serializer expression is an alias of `BoundReference`, which means the buffer
+ // object doesn't need serialization.
val isSimpleBuffer = {
bufferSerializer.head match {
case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true
@@ -76,7 +76,7 @@ object TypedAggregateExpression {
None,
bufferSerializer,
bufferEncoder.resolveAndBind().deserializer,
- outputEncoder.serializer,
+ outputEncoder.objSerializer,
outputType,
outputEncoder.objSerializer.nullable)
}
@@ -213,7 +213,7 @@ case class ComplexTypedAggregateExpression(
inputSchema: Option[StructType],
bufferSerializer: Seq[NamedExpression],
bufferDeserializer: Expression,
- outputSerializer: Seq[Expression],
+ outputSerializer: Expression,
dataType: DataType,
nullable: Boolean,
mutableAggBufferOffset: Int = 0,
@@ -245,13 +245,7 @@ case class ComplexTypedAggregateExpression(
aggregator.merge(buffer, input)
}
- private lazy val resultObjToRow = dataType match {
- case _: StructType =>
- UnsafeProjection.create(CreateStruct(outputSerializer))
- case _ =>
- assert(outputSerializer.length == 1)
- UnsafeProjection.create(outputSerializer.head)
- }
+ private lazy val resultObjToRow = UnsafeProjection.create(outputSerializer)
override def eval(buffer: Any): Any = {
val resultObj = aggregator.finish(buffer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 3b6588587c35a..73eb65f84489c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -27,9 +27,10 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{LongAccumulator, Utils}
+import org.apache.spark.util.LongAccumulator
/**
@@ -209,5 +210,5 @@ case class InMemoryRelation(
override protected def otherCopyArgs: Seq[AnyRef] = Seq(statsOfPlanToCache)
override def simpleString: String =
- s"InMemoryRelation [${Utils.truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}"
+ s"InMemoryRelation [${truncatedString(output, ", ")}], ${cacheBuilder.storageLevel}"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 823dc0d5ed387..e2cd40906f401 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -231,7 +231,8 @@ case class AlterTableAddColumnsCommand(
}
if (DDLUtils.isDatasourceTable(catalogTable)) {
- DataSource.lookupDataSource(catalogTable.provider.get, conf).newInstance() match {
+ DataSource.lookupDataSource(catalogTable.provider.get, conf).
+ getConstructor().newInstance() match {
// For datasource table, this command can only support the following File format.
// TextFileFormat only default to one column "value"
// Hive type is already considered as hive serde table, so the logic will not
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 55e627135c877..346d1f36ad9b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -206,7 +206,7 @@ case class DataSource(
/** Returns the name and schema of the source that can be used to continually read data. */
private def sourceSchema(): SourceInfo = {
- providingClass.newInstance() match {
+ providingClass.getConstructor().newInstance() match {
case s: StreamSourceProvider =>
val (name, schema) = s.sourceSchema(
sparkSession.sqlContext, userSpecifiedSchema, className, caseInsensitiveOptions)
@@ -252,7 +252,7 @@ case class DataSource(
/** Returns a source that can be used to continually read data. */
def createSource(metadataPath: String): Source = {
- providingClass.newInstance() match {
+ providingClass.getConstructor().newInstance() match {
case s: StreamSourceProvider =>
s.createSource(
sparkSession.sqlContext,
@@ -281,7 +281,7 @@ case class DataSource(
/** Returns a sink that can be used to continually write data. */
def createSink(outputMode: OutputMode): Sink = {
- providingClass.newInstance() match {
+ providingClass.getConstructor().newInstance() match {
case s: StreamSinkProvider =>
s.createSink(sparkSession.sqlContext, caseInsensitiveOptions, partitionColumns, outputMode)
@@ -312,7 +312,7 @@ case class DataSource(
* that files already exist, we don't need to check them again.
*/
def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = {
- val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
+ val relation = (providingClass.getConstructor().newInstance(), userSpecifiedSchema) match {
// TODO: Throw when too much is given.
case (dataSource: SchemaRelationProvider, Some(schema)) =>
dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema)
@@ -481,7 +481,7 @@ case class DataSource(
throw new AnalysisException("Cannot save interval data type into external storage.")
}
- providingClass.newInstance() match {
+ providingClass.getConstructor().newInstance() match {
case dataSource: CreatableRelationProvider =>
dataSource.createRelation(
sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data))
@@ -518,7 +518,7 @@ case class DataSource(
throw new AnalysisException("Cannot save interval data type into external storage.")
}
- providingClass.newInstance() match {
+ providingClass.getConstructor().newInstance() match {
case dataSource: CreatableRelationProvider =>
SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode)
case format: FileFormat =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index fe27b78bf3360..62ab5c80d47cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -150,7 +150,7 @@ object FileSourceStrategy extends Strategy with Logging {
// The attribute name of predicate could be different than the one in schema in case of
// case insensitive, we should change them to match the one in schema, so we do not need to
// worry about case sensitivity anymore.
- val normalizedFilters = filters.map { e =>
+ val normalizedFilters = filters.filterNot(SubqueryExpression.hasSubquery).map { e =>
e transform {
case a: AttributeReference =>
a.withName(l.output.find(_.semanticEquals(a)).get.name)
@@ -163,7 +163,6 @@ object FileSourceStrategy extends Strategy with Logging {
val partitionSet = AttributeSet(partitionColumns)
val partitionKeyFilters =
ExpressionSet(normalizedFilters
- .filterNot(SubqueryExpression.hasSubquery(_))
.filter(_.references.subsetOf(partitionSet)))
logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 8d715f6342988..1023572d19e2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -21,8 +21,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.sources.BaseRelation
-import org.apache.spark.util.Utils
/**
* Used to link a [[BaseRelation]] in to a logical query plan.
@@ -63,7 +63,7 @@ case class LogicalRelation(
case _ => // Do nothing.
}
- override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation"
+ override def simpleString: String = s"Relation[${truncatedString(output, ",")}] $relation"
}
object LogicalRelation {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 16b2367bfdd5c..329b9539f52e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -42,7 +42,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
// The attribute name of predicate could be different than the one in schema in case of
// case insensitive, we should change them to match the one in schema, so we do not need to
// worry about case sensitivity anymore.
- val normalizedFilters = filters.map { e =>
+ val normalizedFilters = filters.filterNot(SubqueryExpression.hasSubquery).map { e =>
e transform {
case a: AttributeReference =>
a.withName(logicalRelation.output.find(_.semanticEquals(a)).get.name)
@@ -56,7 +56,6 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
val partitionSet = AttributeSet(partitionColumns)
val partitionKeyFilters =
ExpressionSet(normalizedFilters
- .filterNot(SubqueryExpression.hasSubquery(_))
.filter(_.references.subsetOf(partitionSet)))
if (partitionKeyFilters.nonEmpty) {
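
In both FileSourceStrategy and PruneFileSourcePartitions, the `filterNot(SubqueryExpression.hasSubquery)` step now runs before attribute normalization, so predicates containing subqueries are dropped up front and never considered for partition pruning. A toy sketch of the selection logic, with simplified placeholder types standing in for Catalyst expressions:

    // Placeholder predicate type; Catalyst uses Expression with an AttributeSet.
    case class Predicate(name: String, references: Set[String], hasSubquery: Boolean)

    object PartitionFilterSketch {
      def partitionKeyFilters(filters: Seq[Predicate], partitionCols: Set[String]): Seq[Predicate] =
        filters
          .filterNot(_.hasSubquery)                        // dropped before any rewriting, as in the diff
          .filter(_.references.subsetOf(partitionCols))    // only partition-column-only predicates can prune

      def main(args: Array[String]): Unit = {
        val filters = Seq(
          Predicate("dt = '2018-11-01'", Set("dt"), hasSubquery = false),
          Predicate("id IN (subquery)", Set("id"), hasSubquery = true),
          Predicate("value > 10", Set("value"), hasSubquery = false))
        println(partitionKeyFilters(filters, partitionCols = Set("dt")).map(_.name))
        // List(dt = '2018-11-01')
      }
    }
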
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 4808e8ef042d1..b35b8851918b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -95,7 +95,7 @@ object TextInputCSVDataSource extends CSVDataSource {
headerChecker: CSVHeaderChecker,
requiredSchema: StructType): Iterator[InternalRow] = {
val lines = {
- val linesReader = new HadoopFileLinesReader(file, conf)
+ val linesReader = new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => linesReader.close()))
linesReader.map { line =>
new String(line.getBytes, 0, line.getLength, parser.options.charset)
@@ -192,7 +192,8 @@ object MultiLineCSVDataSource extends CSVDataSource {
UnivocityParser.tokenizeStream(
CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, path),
shouldDropHeader = false,
- new CsvParser(parsedOptions.asParserSettings))
+ new CsvParser(parsedOptions.asParserSettings),
+ encoding = parsedOptions.charset)
}.take(1).headOption match {
case Some(firstRow) =>
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
@@ -203,7 +204,8 @@ object MultiLineCSVDataSource extends CSVDataSource {
lines.getConfiguration,
new Path(lines.getPath())),
parsedOptions.headerFlag,
- new CsvParser(parsedOptions.asParserSettings))
+ new CsvParser(parsedOptions.asParserSettings),
+ encoding = parsedOptions.charset)
}
val sampled = CSVUtils.sample(tokenRDD, parsedOptions)
CSVInferSchema.infer(sampled, header, parsedOptions)
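
These hunks thread the configured record separator and charset through to the line reader and the Univocity tokenizer. A sketch of exercising them via the public reader API, assuming a local SparkSession and assuming the `encoding` and `lineSep` reader options are what back `charset` / `lineSeparatorInRead` above; the path and values are placeholders:

    import org.apache.spark.sql.SparkSession

    object CsvEncodingExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("csv-encoding").getOrCreate()
        val df = spark.read
          .option("header", "true")
          .option("encoding", "ISO-8859-1") // charset used when decoding each record
          .option("lineSep", "\n")          // record separator handed to the lines reader
          .csv("/tmp/example.csv")          // placeholder input path
        df.printSchema()
        spark.stop()
      }
    }
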
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 964b56e706a0b..ff1911d69a6b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityGenerator, UnivocityParser}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
@@ -110,13 +111,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
// Check a field requirement for corrupt records here to throw an exception in a driver side
- dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
- val f = dataSchema(corruptFieldIndex)
- if (f.dataType != StringType || !f.nullable) {
- throw new AnalysisException(
- "The field for corrupt records must be string type and nullable")
- }
- }
+ ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
if (requiredSchema.length == 1 &&
requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
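
The corrupt-record check removed here (and from JsonFileFormat below) moves into a shared `ExprUtils.verifyColumnNameOfCorruptRecord` helper. Its shape, reconstructed from the deleted lines, is roughly the following, with a plain `IllegalArgumentException` standing in for the `AnalysisException` Spark throws from inside its own package:

    import org.apache.spark.sql.types.{StringType, StructType}

    object CorruptRecordCheck {
      // Mirrors the inlined check deleted above: the corrupt-record column, if present,
      // must be a nullable string column.
      def verifyColumnNameOfCorruptRecord(schema: StructType, columnNameOfCorruptRecord: String): Unit = {
        schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
          val f = schema(corruptFieldIndex)
          if (f.dataType != StringType || !f.nullable) {
            throw new IllegalArgumentException(
              "The field for corrupt records must be string type and nullable")
          }
        }
      }
    }
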
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
index 1723596de1db2..530d836d9fde3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
@@ -50,7 +50,7 @@ object DriverRegistry extends Logging {
} else {
synchronized {
if (wrapperMap.get(className).isEmpty) {
- val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver])
+ val wrapper = new DriverWrapper(cls.getConstructor().newInstance().asInstanceOf[Driver])
DriverManager.registerDriver(wrapper)
wrapperMap(className) = wrapper
logTrace(s"Wrapper for $className registered")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index f15014442e3fb..51c385e25bee3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -27,10 +27,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType}
-import org.apache.spark.util.Utils
/**
* Instructions on how to partition the table among workers.
@@ -159,8 +159,9 @@ private[sql] object JDBCRelation extends Logging {
val column = schema.find { f =>
resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName)
}.getOrElse {
+ val maxNumToStringFields = SQLConf.get.maxToStringFields
throw new AnalysisException(s"User-defined partition column $columnName not " +
- s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}")
+ s"found in the JDBC relation: ${schema.simpleString(maxNumToStringFields)}")
}
column.dataType match {
case _: NumericType | DateType | TimestampType =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 1f7c9d73f19fe..610f0d1619fc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -26,7 +26,8 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
@@ -107,13 +108,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
val actualSchema =
StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
// Check a field requirement for corrupt records here to throw an exception in a driver side
- dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
- val f = dataSchema(corruptFieldIndex)
- if (f.dataType != StringType || !f.nullable) {
- throw new AnalysisException(
- "The field for corrupt records must be string type and nullable")
- }
- }
+ ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
if (requiredSchema.length == 1 &&
requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
index 84755bfa301f0..7e38fc651a31f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.orc
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.mapreduce.TaskAttemptContext
-import org.apache.orc.mapred.OrcStruct
-import org.apache.orc.mapreduce.OrcOutputFormat
+import org.apache.orc.OrcFile
+import org.apache.orc.mapred.{OrcOutputFormat => OrcMapRedOutputFormat, OrcStruct}
+import org.apache.orc.mapreduce.{OrcMapreduceRecordWriter, OrcOutputFormat}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.OutputWriter
@@ -36,11 +37,17 @@ private[orc] class OrcOutputWriter(
private[this] val serializer = new OrcSerializer(dataSchema)
private val recordWriter = {
- new OrcOutputFormat[OrcStruct]() {
+ val orcOutputFormat = new OrcOutputFormat[OrcStruct]() {
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
new Path(path)
}
- }.getRecordWriter(context)
+ }
+ val filename = orcOutputFormat.getDefaultWorkFile(context, ".orc")
+ val options = OrcMapRedOutputFormat.buildOptions(context.getConfiguration)
+ val writer = OrcFile.createWriter(filename, options)
+ val recordWriter = new OrcMapreduceRecordWriter[OrcStruct](writer)
+ OrcUtils.addSparkVersionMetadata(writer)
+ recordWriter
}
override def write(row: InternalRow): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index 95fb25bf5addb..57d2c56e87b4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -17,18 +17,19 @@
package org.apache.spark.sql.execution.datasources.orc
+import java.nio.charset.StandardCharsets.UTF_8
import java.util.Locale
import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.orc.{OrcFile, Reader, TypeDescription}
+import org.apache.orc.{OrcFile, Reader, TypeDescription, Writer}
-import org.apache.spark.SparkException
+import org.apache.spark.{SPARK_VERSION_SHORT, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession}
import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types._
@@ -144,4 +145,11 @@ object OrcUtils extends Logging {
}
}
}
+
+ /**
+ * Add a metadata entry recording the Spark version.
+ */
+ def addSparkVersionMetadata(writer: Writer): Unit = {
+ writer.addUserMetadata(SPARK_VERSION_METADATA_KEY, UTF_8.encode(SPARK_VERSION_SHORT))
+ }
}
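
OrcOutputWriter now builds the ORC `Writer` directly so that `OrcUtils.addSparkVersionMetadata` can stamp the Spark version before any rows are written. Outside Spark, the same ORC core calls look like this; the path, schema and version value are placeholders, and the key string matches `SPARK_VERSION_METADATA_KEY` defined later in this diff:

    import java.nio.charset.StandardCharsets.UTF_8
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.orc.{OrcFile, TypeDescription}

    object OrcVersionMetadataExample {
      def main(args: Array[String]): Unit = {
        val conf = new Configuration()
        val schema = TypeDescription.fromString("struct<id:int>")
        val options = OrcFile.writerOptions(conf).setSchema(schema)
        val writer = OrcFile.createWriter(new Path("/tmp/example.orc"), options)
        // Same call as addSparkVersionMetadata: key/value pairs land in the ORC footer.
        writer.addUserMetadata("org.apache.spark.version", UTF_8.encode("2.4.0"))
        writer.close()
      }
    }
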
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
index 4096e88ddd123..bca48c33cd523 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
@@ -29,7 +29,9 @@ import org.apache.parquet.hadoop.api.WriteSupport
import org.apache.parquet.hadoop.api.WriteSupport.WriteContext
import org.apache.parquet.io.api.{Binary, RecordConsumer}
+import org.apache.spark.SPARK_VERSION_SHORT
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -92,7 +94,10 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit
this.rootFieldWriters = schema.map(_.dataType).map(makeWriter).toArray[ValueWriter]
val messageType = new SparkToParquetSchemaConverter(configuration).convert(schema)
- val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava
+ val metadata = Map(
+ SPARK_VERSION_METADATA_KEY -> SPARK_VERSION_SHORT,
+ ParquetReadSupport.SPARK_METADATA_KEY -> schemaString
+ ).asJava
logInfo(
s"""Initialized Parquet WriteSupport with Catalyst schema:
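
ParquetWriteSupport now records the same version key in the Parquet file footer alongside the Catalyst schema. A sketch of checking it back with the Parquet reader API; the path is a placeholder and the value depends on the Spark build that wrote the file:

    import scala.collection.JavaConverters._
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.parquet.hadoop.ParquetFileReader
    import org.apache.parquet.hadoop.util.HadoopInputFile

    object ParquetFooterExample {
      def main(args: Array[String]): Unit = {
        val input = HadoopInputFile.fromPath(new Path("/tmp/part-00000.parquet"), new Configuration())
        val reader = ParquetFileReader.open(input)
        try {
          val keyValue = reader.getFooter.getFileMetaData.getKeyValueMetaData.asScala
          // After this change the writer records the Spark version under this key.
          println(keyValue.get("org.apache.spark.version"))
        } finally {
          reader.close()
        }
      }
    }
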
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
index 97e6c6d702acb..e829f621b4ea3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.util.Utils
@@ -72,10 +73,10 @@ trait DataSourceV2StringFormat {
}.mkString("[", ",", "]")
}
- val outputStr = Utils.truncatedString(output, "[", ", ", "]")
+ val outputStr = truncatedString(output, "[", ", ", "]")
val entriesStr = if (entries.nonEmpty) {
- Utils.truncatedString(entries.map {
+ truncatedString(entries.map {
case (key, value) => key + ": " + StringUtils.abbreviate(value, 100)
}, " (", ", ", ")")
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 366e1fe6a4aaa..3511cefa7c292 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -17,10 +17,13 @@
package org.apache.spark.sql.execution
+import java.io.Writer
import java.util.Collections
import scala.collection.JavaConverters._
+import org.apache.commons.io.output.StringBuilderWriter
+
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
@@ -30,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, Codegen
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper}
-import org.apache.spark.sql.execution.streaming.continuous.WriteToContinuousDataSourceExec
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.util.{AccumulatorV2, LongAccumulator}
@@ -70,15 +72,25 @@ package object debug {
* @return single String containing all WholeStageCodegen subtrees and corresponding codegen
*/
def codegenString(plan: SparkPlan): String = {
+ val writer = new StringBuilderWriter()
+
+ try {
+ writeCodegen(writer, plan)
+ writer.toString
+ } finally {
+ writer.close()
+ }
+ }
+
+ def writeCodegen(writer: Writer, plan: SparkPlan): Unit = {
val codegenSeq = codegenStringSeq(plan)
- var output = s"Found ${codegenSeq.size} WholeStageCodegen subtrees.\n"
+ writer.write(s"Found ${codegenSeq.size} WholeStageCodegen subtrees.\n")
for (((subtree, code), i) <- codegenSeq.zipWithIndex) {
- output += s"== Subtree ${i + 1} / ${codegenSeq.size} ==\n"
- output += subtree
- output += "\nGenerated code:\n"
- output += s"${code}\n"
+ writer.write(s"== Subtree ${i + 1} / ${codegenSeq.size} ==\n")
+ writer.write(subtree)
+ writer.write("\nGenerated code:\n")
+ writer.write(s"${code}\n")
}
- output
}
/**
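
The `codegenString` rewrite above keeps the string-returning entry point but streams through a `java.io.Writer`, so large codegen dumps are built without repeated string concatenation. The same pattern in isolation, using the commons-io `StringBuilderWriter` imported above:

    import java.io.Writer
    import org.apache.commons.io.output.StringBuilderWriter

    object WriterPatternSketch {
      // Callers that already have a Writer (a file, a socket, the console) can stream directly.
      def writeReport(writer: Writer, sections: Seq[(String, String)]): Unit = {
        writer.write(s"Found ${sections.size} sections.\n")
        for (((title, body), i) <- sections.zipWithIndex) {
          writer.write(s"== Section ${i + 1} / ${sections.size}: $title ==\n")
          writer.write(body + "\n")
        }
      }

      // String-returning convenience, mirroring codegenString: collect into a
      // StringBuilderWriter and close it when done.
      def reportString(sections: Seq[(String, String)]): String = {
        val writer = new StringBuilderWriter()
        try {
          writeReport(writer, sections)
          writer.toString
        } finally {
          writer.close()
        }
      }

      def main(args: Array[String]): Unit =
        println(reportString(Seq("plan" -> "Range(0, 10)", "code" -> "/* generated */")))
    }
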
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index bbd1a3f005d74..6387ea27c342a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -26,7 +26,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
@@ -47,7 +47,8 @@ case class ShuffleExchangeExec(
// e.g. it can be null on the Executor side
override lazy val metrics = Map(
- "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"))
+ "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")
+ ) ++ SQLMetrics.getShuffleReadMetrics(sparkContext)
override def nodeName: String = {
"Exchange"
@@ -85,7 +86,7 @@ case class ShuffleExchangeExec(
assert(newPartitioning.isInstanceOf[HashPartitioning])
newPartitioning = UnknownPartitioning(indices.length)
}
- new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
+ new ShuffledRowRDD(shuffleDependency, metrics, specifiedPartitionStartIndices)
}
/**
@@ -198,13 +199,21 @@ object ShuffleExchangeExec {
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}
case RangePartitioning(sortingExpressions, numPartitions) =>
- // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
- // partition bounds. To get accurate samples, we need to copy the mutable keys.
+ // Extract only the fields used for sorting, to avoid collecting large fields that do not
+ // affect the sort result when deciding partition bounds in RangePartitioner.
val rddForSampling = rdd.mapPartitionsInternal { iter =>
+ val projection =
+ UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
val mutablePair = new MutablePair[InternalRow, Null]()
- iter.map(row => mutablePair.update(row.copy(), null))
+ // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+ // partition bounds. To get accurate samples, we need to copy the mutable keys.
+ iter.map(row => mutablePair.update(projection(row).copy(), null))
}
- implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes)
+ // Construct ordering on extracted sort key.
+ val orderingAttributes = sortingExpressions.zipWithIndex.map { case (ord, i) =>
+ ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
+ }
+ implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
new RangePartitioner(
numPartitions,
rddForSampling,
@@ -230,7 +239,10 @@ object ShuffleExchangeExec {
case h: HashPartitioning =>
val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
row => projection(row).getInt(0)
- case RangePartitioning(_, _) | SinglePartition => identity
+ case RangePartitioning(sortingExpressions, _) =>
+ val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
+ row => projection(row)
+ case SinglePartition => identity
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
}
@@ -253,7 +265,7 @@ object ShuffleExchangeExec {
}
// The comparator for comparing row hashcode, which should always be Integer.
val prefixComparator = PrefixComparators.LONG
- val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED)
+ val canUseRadixSort = SQLConf.get.enableRadixSort
// The prefix computer generates row hashcode as the prefix, so we may decrease the
// probability that the prefixes are equal when input rows choose column values from a
// limited range.
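
For `RangePartitioning`, the rows fed to `RangePartitioner`'s sampling job are now projected down to just the sort-key columns, and the ordering is rebuilt against bound references into that narrow row, so wide payload columns are not copied during sampling. A conceptual sketch with plain RDDs standing in for the `InternalRow`/`UnsafeProjection` machinery:

    import org.apache.spark.{RangePartitioner, SparkConf, SparkContext}

    object RangeBoundsOnKeysOnly {
      // A wide row: only `key` matters for ordering; `payload` is dead weight for sampling.
      final case class WideRow(key: Int, payload: Array[Byte])

      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("range-sample"))
        try {
          val rows = sc.parallelize((1 to 10000).map(i => WideRow(i, new Array[Byte](1024))))
          // Project to the sort key before handing the data to RangePartitioner, so the
          // sampling job only collects small keys instead of whole rows.
          val keysForSampling = rows.map(r => (r.key, ()))
          val partitioner = new RangePartitioner(8, keysForSampling)
          println(s"computed ${partitioner.numPartitions} range partitions")
        } finally {
          sc.stop()
        }
      }
    }
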
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 256f12b605c92..3edfec5e5bb65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering}
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* Take the first `limit` elements and collect them to a single partition.
@@ -40,11 +41,13 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
LocalLimitExec(limit, child).executeToIterator().take(limit)
}
private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+ override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext)
protected override def doExecute(): RDD[InternalRow] = {
val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
val shuffled = new ShuffledRowRDD(
ShuffleExchangeExec.prepareShuffleDependency(
- locallyLimited, child.output, SinglePartition, serializer))
+ locallyLimited, child.output, SinglePartition, serializer),
+ metrics)
shuffled.mapPartitionsInternal(_.take(limit))
}
}
@@ -154,6 +157,8 @@ case class TakeOrderedAndProjectExec(
private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+ override lazy val metrics = SQLMetrics.getShuffleReadMetrics(sparkContext)
+
protected override def doExecute(): RDD[InternalRow] = {
val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
val localTopK: RDD[InternalRow] = {
@@ -163,7 +168,8 @@ case class TakeOrderedAndProjectExec(
}
val shuffled = new ShuffledRowRDD(
ShuffleExchangeExec.prepareShuffleDependency(
- localTopK, child.output, SinglePartition, serializer))
+ localTopK, child.output, SinglePartition, serializer),
+ metrics)
shuffled.mapPartitions { iter =>
val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
if (projectList != child.output) {
@@ -180,8 +186,8 @@ case class TakeOrderedAndProjectExec(
override def outputPartitioning: Partitioning = SinglePartition
override def simpleString: String = {
- val orderByString = Utils.truncatedString(sortOrder, "[", ",", "]")
- val outputString = Utils.truncatedString(output, "[", ",", "]")
+ val orderByString = truncatedString(sortOrder, "[", ",", "]")
+ val outputString = truncatedString(output, "[", ",", "]")
s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index cbf707f4a9cfd..0b5ee3a5e0577 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -82,6 +82,14 @@ object SQLMetrics {
private val baseForAvgMetric: Int = 10
+ val REMOTE_BLOCKS_FETCHED = "remoteBlocksFetched"
+ val LOCAL_BLOCKS_FETCHED = "localBlocksFetched"
+ val REMOTE_BYTES_READ = "remoteBytesRead"
+ val REMOTE_BYTES_READ_TO_DISK = "remoteBytesReadToDisk"
+ val LOCAL_BYTES_READ = "localBytesRead"
+ val FETCH_WAIT_TIME = "fetchWaitTime"
+ val RECORDS_READ = "recordsRead"
+
/**
* Converts a double value to long value by multiplying a base integer, so we can store it in
* `SQLMetrics`. It only works for average metrics. When showing the metrics on UI, we restore
@@ -194,4 +202,16 @@ object SQLMetrics {
SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value)))
}
}
+
+ /**
+ * Create all shuffle-read-related metrics and return them as a Map.
+ */
+ def getShuffleReadMetrics(sc: SparkContext): Map[String, SQLMetric] = Map(
+ REMOTE_BLOCKS_FETCHED -> createMetric(sc, "remote blocks fetched"),
+ LOCAL_BLOCKS_FETCHED -> createMetric(sc, "local blocks fetched"),
+ REMOTE_BYTES_READ -> createSizeMetric(sc, "remote bytes read"),
+ REMOTE_BYTES_READ_TO_DISK -> createSizeMetric(sc, "remote bytes read to disk"),
+ LOCAL_BYTES_READ -> createSizeMetric(sc, "local bytes read"),
+ FETCH_WAIT_TIME -> createTimingMetric(sc, "fetch wait time"),
+ RECORDS_READ -> createMetric(sc, "records read"))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala
new file mode 100644
index 0000000000000..542141ea4b4e6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.metric
+
+import org.apache.spark.executor.TempShuffleReadMetrics
+
+/**
+ * A shuffle metrics reporter for SQL exchange operators.
+ * @param tempMetrics [[TempShuffleReadMetrics]] created in TaskContext.
+ * @param metrics All metrics in the current SparkPlan. This param should not be empty and
+ * must contain all shuffle metrics defined in [[SQLMetrics.getShuffleReadMetrics]].
+ */
+private[spark] class SQLShuffleMetricsReporter(
+ tempMetrics: TempShuffleReadMetrics,
+ metrics: Map[String, SQLMetric]) extends TempShuffleReadMetrics {
+ private[this] val _remoteBlocksFetched = metrics(SQLMetrics.REMOTE_BLOCKS_FETCHED)
+ private[this] val _localBlocksFetched = metrics(SQLMetrics.LOCAL_BLOCKS_FETCHED)
+ private[this] val _remoteBytesRead = metrics(SQLMetrics.REMOTE_BYTES_READ)
+ private[this] val _remoteBytesReadToDisk = metrics(SQLMetrics.REMOTE_BYTES_READ_TO_DISK)
+ private[this] val _localBytesRead = metrics(SQLMetrics.LOCAL_BYTES_READ)
+ private[this] val _fetchWaitTime = metrics(SQLMetrics.FETCH_WAIT_TIME)
+ private[this] val _recordsRead = metrics(SQLMetrics.RECORDS_READ)
+
+ override def incRemoteBlocksFetched(v: Long): Unit = {
+ _remoteBlocksFetched.add(v)
+ tempMetrics.incRemoteBlocksFetched(v)
+ }
+ override def incLocalBlocksFetched(v: Long): Unit = {
+ _localBlocksFetched.add(v)
+ tempMetrics.incLocalBlocksFetched(v)
+ }
+ override def incRemoteBytesRead(v: Long): Unit = {
+ _remoteBytesRead.add(v)
+ tempMetrics.incRemoteBytesRead(v)
+ }
+ override def incRemoteBytesReadToDisk(v: Long): Unit = {
+ _remoteBytesReadToDisk.add(v)
+ tempMetrics.incRemoteBytesReadToDisk(v)
+ }
+ override def incLocalBytesRead(v: Long): Unit = {
+ _localBytesRead.add(v)
+ tempMetrics.incLocalBytesRead(v)
+ }
+ override def incFetchWaitTime(v: Long): Unit = {
+ _fetchWaitTime.add(v)
+ tempMetrics.incFetchWaitTime(v)
+ }
+ override def incRecordsRead(v: Long): Unit = {
+ _recordsRead.add(v)
+ tempMetrics.incRecordsRead(v)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
index d2820ff335ecf..eb12641f548ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -23,7 +23,7 @@ import com.google.common.io.Closeables
import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.io.NioBufferedFileInputStream
-import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
+import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.Platform
@@ -226,7 +226,7 @@ private[python] case class HybridRowQueue(
val page = try {
allocatePage(required)
} catch {
- case _: OutOfMemoryError =>
+ case _: SparkOutOfMemoryError =>
null
}
val buffer = if (page != null) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 86f6307254332..420faa6f24734 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -69,7 +69,7 @@ object FrequentItems extends Logging {
/**
* Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
- * here, proposed by Karp, Schenker,
+ * here, proposed by Karp, Schenker,
* and Papadimitriou.
* The `support` should be greater than 1e-4.
* For Internal use only.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index bea652cc33076..ac25a8fd90bc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -45,7 +45,7 @@ object StatFunctions extends Logging {
*
* This method implements a variation of the Greenwald-Khanna algorithm (with some speed
* optimizations).
- * The algorithm was first present in
+ * The algorithm was first present in
* Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna.
*
* @param df the dataframe
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala
index 606ba250ad9d2..b3e4240c315bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala
@@ -56,7 +56,7 @@ trait CheckpointFileManager {
* @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to
* overwrite the file if it already exists. It should not throw
* any exception if the file exists. However, if false, then the
- * implementation must not overwrite if the file alraedy exists and
+ * implementation must not overwrite if the file already exists and
* must throw `FileAlreadyExistsException` in that case.
*/
def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 2cac86599ef19..5defca391a355 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -24,13 +24,14 @@ import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport}
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2}
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
-import org.apache.spark.util.{Clock, Utils}
+import org.apache.spark.util.Clock
class MicroBatchExecution(
sparkSession: SparkSession,
@@ -475,8 +476,8 @@ class MicroBatchExecution(
case StreamingExecutionRelation(source, output) =>
newData.get(source).map { dataPlan =>
assert(output.size == dataPlan.output.size,
- s"Invalid batch: ${Utils.truncatedString(output, ",")} != " +
- s"${Utils.truncatedString(dataPlan.output, ",")}")
+ s"Invalid batch: ${truncatedString(output, ",")} != " +
+ s"${truncatedString(dataPlan.output, ",")}")
val aliases = output.zip(dataPlan.output).map { case (to, from) =>
Alias(from, to.name)(exprId = to.exprId, explicitMetadata = Some(from.metadata))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 631a6eb649ffb..89b4f40c9c0b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -88,6 +88,7 @@ abstract class StreamExecution(
val resolvedCheckpointRoot = {
val checkpointPath = new Path(checkpointRoot)
val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf())
+ fs.mkdirs(checkpointPath)
checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala
index 0bc54eac4ee8e..516afbea5d9de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala
@@ -19,16 +19,18 @@ package org.apache.spark.sql.execution.streaming
import java.io.{InputStreamReader, OutputStreamWriter}
import java.nio.charset.StandardCharsets
+import java.util.ConcurrentModificationException
import scala.util.control.NonFatal
import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, FSDataOutputStream, Path}
+import org.apache.hadoop.fs.{FileAlreadyExistsException, FSDataInputStream, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream
import org.apache.spark.sql.streaming.StreamingQuery
/**
@@ -70,19 +72,26 @@ object StreamMetadata extends Logging {
metadata: StreamMetadata,
metadataFile: Path,
hadoopConf: Configuration): Unit = {
- var output: FSDataOutputStream = null
+ var output: CancellableFSDataOutputStream = null
try {
- val fs = metadataFile.getFileSystem(hadoopConf)
- output = fs.create(metadataFile)
+ val fileManager = CheckpointFileManager.create(metadataFile.getParent, hadoopConf)
+ output = fileManager.createAtomic(metadataFile, overwriteIfPossible = false)
val writer = new OutputStreamWriter(output)
Serialization.write(metadata, writer)
writer.close()
} catch {
- case NonFatal(e) =>
+ case e: FileAlreadyExistsException =>
+ if (output != null) {
+ output.cancel()
+ }
+ throw new ConcurrentModificationException(
+ s"Multiple streaming queries are concurrently using $metadataFile", e)
+ case e: Throwable =>
+ if (output != null) {
+ output.cancel()
+ }
logError(s"Error writing stream metadata $metadata to $metadataFile", e)
throw e
- } finally {
- IOUtils.closeQuietly(output)
}
}
}
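
StreamMetadata now writes through `CheckpointFileManager.createAtomic` and explicitly cancels the in-flight stream on failure, turning `FileAlreadyExistsException` into a `ConcurrentModificationException`. A generic write-temp-then-rename sketch of the same cancel-on-failure idea, using plain java.nio rather than Spark's checkpoint manager:

    import java.io.IOException
    import java.nio.charset.StandardCharsets.UTF_8
    import java.nio.file.{Files, Path, StandardCopyOption}

    object AtomicWriteSketch {
      def writeAtomically(target: Path, contents: String): Unit = {
        val temp = Files.createTempFile(target.getParent, ".metadata-", ".tmp")
        try {
          Files.write(temp, contents.getBytes(UTF_8))
          if (Files.exists(target)) {
            // Mirrors overwriteIfPossible = false: another writer got there first.
            throw new IOException(s"$target already exists")
          }
          Files.move(temp, target, StandardCopyOption.ATOMIC_MOVE)
        } catch {
          case e: Throwable =>
            Files.deleteIfExists(temp) // "cancel" the half-written output before rethrowing
            throw e
        }
      }
    }
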
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
index 19e3e55cb2829..4c0db3cb42a82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.streaming
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.{Evolving, Experimental}
import org.apache.spark.sql.streaming.Trigger
/**
@@ -25,5 +25,5 @@ import org.apache.spark.sql.streaming.Trigger
* the query.
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
case object OneTimeTrigger extends Trigger
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index f009c52449adc..1eab55122e84b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -28,6 +28,7 @@ import org.apache.spark.SparkEnv
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation}
import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _}
@@ -35,7 +36,7 @@ import org.apache.spark.sql.sources.v2
import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider}
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset}
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
-import org.apache.spark.util.{Clock, Utils}
+import org.apache.spark.util.Clock
class ContinuousExecution(
sparkSession: SparkSession,
@@ -164,8 +165,8 @@ class ContinuousExecution(
val newOutput = readSupport.fullSchema().toAttributes
assert(output.size == newOutput.size,
- s"Invalid reader: ${Utils.truncatedString(output, ",")} != " +
- s"${Utils.truncatedString(newOutput, ",")}")
+ s"Invalid reader: ${truncatedString(output, ",")} != " +
+ s"${truncatedString(newOutput, ",")}")
replacements ++= output.zip(newOutput)
val loggedOffset = offsets.offsets(0)
@@ -262,7 +263,12 @@ class ContinuousExecution(
reportTimeTaken("runContinuous") {
SQLExecution.withNewExecutionId(
- sparkSessionForQuery, lastExecution)(lastExecution.toRdd)
+ sparkSessionForQuery, lastExecution) {
+ // Materialize `executedPlan` so that accessing it when `toRdd` is running doesn't need to
+ // wait for a lock
+ lastExecution.executedPlan
+ lastExecution.toRdd
+ }
}
} catch {
case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) &&
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala
index 90e1766c4d9f1..caffcc3c4c1a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTrigger.scala
@@ -23,15 +23,15 @@ import scala.concurrent.duration.Duration
import org.apache.commons.lang3.StringUtils
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.streaming.{ProcessingTime, Trigger}
+import org.apache.spark.annotation.Evolving
+import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.unsafe.types.CalendarInterval
/**
* A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at
* the specified interval.
*/
-@InterfaceStability.Evolving
+@Evolving
case class ContinuousTrigger(intervalMs: Long) extends Trigger {
require(intervalMs >= 0, "the interval of trigger should not be negative")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index adf52aba21a04..daee089f3871d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -31,11 +31,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
object MemoryStream {
protected val currentBlockId = new AtomicInteger(0)
@@ -117,7 +117,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
}
- override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"
+ override def toString: String = s"MemoryStream[${truncatedString(output, ",")}]"
override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index d3313b8a315c9..7d785aa09cd9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -213,7 +213,7 @@ object StateStoreProvider {
*/
def create(providerClassName: String): StateStoreProvider = {
val providerClass = Utils.classForName(providerClassName)
- providerClass.newInstance().asInstanceOf[StateStoreProvider]
+ providerClass.getConstructor().newInstance().asInstanceOf[StateStoreProvider]
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index 2959eacf95dae..a656a2f53e0a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -159,7 +159,6 @@ class SQLAppStatusListener(
}
private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
- val metricIds = exec.metrics.map(_.accumulatorId).sorted
val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
val metrics = exec.stages.toSeq
.flatMap { stageId => Option(stageMetrics.get(stageId)) }
@@ -167,10 +166,10 @@ class SQLAppStatusListener(
.flatMap { metrics => metrics.ids.zip(metrics.values) }
val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq)
- .filter { case (id, _) => metricIds.contains(id) }
+ .filter { case (id, _) => metricTypes.contains(id) }
.groupBy(_._1)
.map { case (id, values) =>
- id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2).toSeq)
+ id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2))
}
// Check the execution again for whether the aggregated metrics data has been calculated.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 1e076207bc607..6b4def35e1955 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -17,8 +17,8 @@
package org.apache.spark.sql.expressions
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{Dataset, Encoder, TypedColumn}
+import org.apache.spark.annotation.{Evolving, Experimental}
+import org.apache.spark.sql.{Encoder, TypedColumn}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
@@ -51,7 +51,7 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
* @since 1.6.0
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index eb956c4b3e888..58a942afe28c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.expressions
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class UserDefinedFunction protected[sql] (
f: AnyRef,
dataType: DataType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index 14dec8f0810f2..64375085a64ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.expressions
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions._
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions._
*
* @since 1.4.0
*/
-@InterfaceStability.Stable
+@Stable
object Window {
/**
@@ -243,5 +243,5 @@ object Window {
*
* @since 1.4.0
*/
-@InterfaceStability.Stable
+@Stable
class Window private() // So we can see Window in JavaDoc.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index 0cc43a58237df..0ace9d25f9dc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.expressions
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.{AnalysisException, Column}
import org.apache.spark.sql.catalyst.expressions._
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
*
* @since 1.4.0
*/
-@InterfaceStability.Stable
+@Stable
class WindowSpec private[sql](
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala
index 3e637d594caf3..1cb579c4faa76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.expressions.scalalang
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.{Evolving, Experimental}
import org.apache.spark.sql._
import org.apache.spark.sql.execution.aggregate._
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.aggregate._
* @since 2.0.0
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
// scalastyle:off
object typed {
// scalastyle:on
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 4976b875fa298..4e8cb3a6ddd66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.expressions
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
*
* @since 1.5.0
*/
-@InterfaceStability.Stable
+@Stable
abstract class UserDefinedAggregateFunction extends Serializable {
/**
@@ -159,7 +159,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
*
* @since 1.5.0
*/
-@InterfaceStability.Stable
+@Stable
abstract class MutableAggregationBuffer extends Row {
/** Update the ith value of this buffer. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 8edc57834b2f5..274be8ae7c835 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import scala.util.Try
import scala.util.control.NonFatal
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
@@ -68,7 +68,7 @@ import org.apache.spark.util.Utils
* @groupname Ungrouped Support functions for DataFrames
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
// scalastyle:off
object functions {
// scalastyle:on
@@ -3836,7 +3836,7 @@ object functions {
/**
* Returns an unordered array of all entries in the given map.
* @group collection_funcs
- * @since 2.4.0
+ * @since 3.0.0
*/
def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index f67cc32c15dd2..ac07e1f6bb4f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.internal
import org.apache.spark.SparkConf
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.{Experimental, Unstable}
import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
@@ -50,7 +50,7 @@ import org.apache.spark.sql.util.ExecutionListenerManager
* and `catalog` fields. Note that the state is cloned when `build` is called, and not before.
*/
@Experimental
-@InterfaceStability.Unstable
+@Unstable
abstract class BaseSessionStateBuilder(
val session: SparkSession,
val parentState: Option[SessionState] = None) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index accbea41b9603..b34db581ca2c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.annotation.{Experimental, Unstable}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog._
@@ -124,7 +124,7 @@ private[sql] object SessionState {
* Concrete implementation of a [[BaseSessionStateBuilder]].
*/
@Experimental
-@InterfaceStability.Unstable
+@Unstable
class SessionStateBuilder(
session: SparkSession,
parentState: Option[SessionState] = None)
@@ -135,7 +135,7 @@ class SessionStateBuilder(
/**
* Session shared [[FunctionResourceLoader]].
*/
-@InterfaceStability.Unstable
+@Unstable
class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoader {
override def loadResource(resource: FunctionResource): Unit = {
resource.resourceType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index f76c1fae562c6..230b43022b02b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -21,8 +21,7 @@ import java.sql.{Connection, Date, Timestamp}
import org.apache.commons.lang3.StringUtils
-import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since}
-import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.annotation.{DeveloperApi, Evolving, Since}
import org.apache.spark.sql.types._
/**
@@ -34,7 +33,7 @@ import org.apache.spark.sql.types._
* send a null value to the database.
*/
@DeveloperApi
-@InterfaceStability.Evolving
+@Evolving
case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
/**
@@ -57,7 +56,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
* for the given Catalyst type.
*/
@DeveloperApi
-@InterfaceStability.Evolving
+@Evolving
abstract class JdbcDialect extends Serializable {
/**
* Check if this dialect instance can handle a certain jdbc url.
@@ -197,7 +196,7 @@ abstract class JdbcDialect extends Serializable {
* sure to register your dialects first.
*/
@DeveloperApi
-@InterfaceStability.Evolving
+@Evolving
object JdbcDialects {
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 161e0102f0b43..61875931d226e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -17,7 +17,7 @@
package org.apache.spark
-import org.apache.spark.annotation.{DeveloperApi, InterfaceStability}
+import org.apache.spark.annotation.{DeveloperApi, Unstable}
import org.apache.spark.sql.execution.SparkStrategy
/**
@@ -40,8 +40,17 @@ package object sql {
* [[org.apache.spark.sql.sources]]
*/
@DeveloperApi
- @InterfaceStability.Unstable
+ @Unstable
type Strategy = SparkStrategy
type DataFrame = Dataset[Row]
+
+ /**
+ * Metadata key which is used to write Spark version in the followings:
+ * - Parquet file metadata
+ * - ORC file metadata
+ *
+ * Note that Hive table property `spark.sql.create.version` also has Spark version.
+ */
+ private[sql] val SPARK_VERSION_METADATA_KEY = "org.apache.spark.version"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
index bdd8c4da6bd30..3f941cc6e1072 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Stable
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines all the filters that we can push down to the data sources.
@@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
abstract class Filter {
/**
* List of columns that are referenced by this filter.
@@ -48,7 +48,7 @@ abstract class Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class EqualTo(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
}
@@ -60,7 +60,7 @@ case class EqualTo(attribute: String, value: Any) extends Filter {
*
* @since 1.5.0
*/
-@InterfaceStability.Stable
+@Stable
case class EqualNullSafe(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
}
@@ -71,7 +71,7 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class GreaterThan(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
}
@@ -82,7 +82,7 @@ case class GreaterThan(attribute: String, value: Any) extends Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
}
@@ -93,7 +93,7 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class LessThan(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
}
@@ -104,7 +104,7 @@ case class LessThan(attribute: String, value: Any) extends Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class LessThanOrEqual(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
}
@@ -114,7 +114,7 @@ case class LessThanOrEqual(attribute: String, value: Any) extends Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class In(attribute: String, values: Array[Any]) extends Filter {
override def hashCode(): Int = {
var h = attribute.hashCode
@@ -141,7 +141,7 @@ case class In(attribute: String, values: Array[Any]) extends Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class IsNull(attribute: String) extends Filter {
override def references: Array[String] = Array(attribute)
}
@@ -151,7 +151,7 @@ case class IsNull(attribute: String) extends Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class IsNotNull(attribute: String) extends Filter {
override def references: Array[String] = Array(attribute)
}
@@ -161,7 +161,7 @@ case class IsNotNull(attribute: String) extends Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class And(left: Filter, right: Filter) extends Filter {
override def references: Array[String] = left.references ++ right.references
}
@@ -171,7 +171,7 @@ case class And(left: Filter, right: Filter) extends Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class Or(left: Filter, right: Filter) extends Filter {
override def references: Array[String] = left.references ++ right.references
}
@@ -181,7 +181,7 @@ case class Or(left: Filter, right: Filter) extends Filter {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
case class Not(child: Filter) extends Filter {
override def references: Array[String] = child.references
}
@@ -192,7 +192,7 @@ case class Not(child: Filter) extends Filter {
*
* @since 1.3.1
*/
-@InterfaceStability.Stable
+@Stable
case class StringStartsWith(attribute: String, value: String) extends Filter {
override def references: Array[String] = Array(attribute)
}
@@ -203,7 +203,7 @@ case class StringStartsWith(attribute: String, value: String) extends Filter {
*
* @since 1.3.1
*/
-@InterfaceStability.Stable
+@Stable
case class StringEndsWith(attribute: String, value: String) extends Filter {
override def references: Array[String] = Array(attribute)
}
@@ -214,7 +214,7 @@ case class StringEndsWith(attribute: String, value: String) extends Filter {
*
* @since 1.3.1
*/
-@InterfaceStability.Stable
+@Stable
case class StringContains(attribute: String, value: String) extends Filter {
override def references: Array[String] = Array(attribute)
}
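
Only the annotation form changes in the filter hierarchy above. A quick illustration of how these case classes compose and how `references` is derived (the column names and value are arbitrary):

    import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}

    // A pushed-down predicate equivalent to: age = 21 AND name IS NOT NULL
    val pushed: Filter = And(EqualTo("age", 21), IsNotNull("name"))

    // And.references concatenates the children's references: Array("age", "name")
    assert(pushed.references.sameElements(Array("age", "name")))
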
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 6057a795c8bf5..6ad054c9f6403 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources
-import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.annotation._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
@@ -35,7 +35,7 @@ import org.apache.spark.sql.types.StructType
*
* @since 1.5.0
*/
-@InterfaceStability.Stable
+@Stable
trait DataSourceRegister {
/**
@@ -65,7 +65,7 @@ trait DataSourceRegister {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
trait RelationProvider {
/**
* Returns a new base relation with the given parameters.
@@ -96,7 +96,7 @@ trait RelationProvider {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
trait SchemaRelationProvider {
/**
* Returns a new base relation with the given parameters and user defined schema.
@@ -117,7 +117,7 @@ trait SchemaRelationProvider {
* @since 2.0.0
*/
@Experimental
-@InterfaceStability.Unstable
+@Unstable
trait StreamSourceProvider {
/**
@@ -148,7 +148,7 @@ trait StreamSourceProvider {
* @since 2.0.0
*/
@Experimental
-@InterfaceStability.Unstable
+@Unstable
trait StreamSinkProvider {
def createSink(
sqlContext: SQLContext,
@@ -160,7 +160,7 @@ trait StreamSinkProvider {
/**
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
trait CreatableRelationProvider {
/**
* Saves a DataFrame to a destination (using data source-specific parameters)
@@ -192,7 +192,7 @@ trait CreatableRelationProvider {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
abstract class BaseRelation {
def sqlContext: SQLContext
def schema: StructType
@@ -242,7 +242,7 @@ abstract class BaseRelation {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
trait TableScan {
def buildScan(): RDD[Row]
}
@@ -253,7 +253,7 @@ trait TableScan {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
trait PrunedScan {
def buildScan(requiredColumns: Array[String]): RDD[Row]
}
@@ -271,7 +271,7 @@ trait PrunedScan {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
trait PrunedFilteredScan {
def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row]
}
@@ -293,7 +293,7 @@ trait PrunedFilteredScan {
*
* @since 1.3.0
*/
-@InterfaceStability.Stable
+@Stable
trait InsertableRelation {
def insert(data: DataFrame, overwrite: Boolean): Unit
}
@@ -309,7 +309,7 @@ trait InsertableRelation {
* @since 1.3.0
*/
@Experimental
-@InterfaceStability.Unstable
+@Unstable
trait CatalystScan {
def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row]
}
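
The traits above keep their shapes; only the stability annotations move. As a reminder of how the v1 interfaces fit together, here is a hedged sketch of a trivial read-only source (class and column names are illustrative only):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, TableScan}
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    class OneRowSource extends RelationProvider {
      override def createRelation(
          sqlContext: SQLContext,
          parameters: Map[String, String]): BaseRelation = {
        val ctx = sqlContext
        new BaseRelation with TableScan {
          override def sqlContext: SQLContext = ctx
          override def schema: StructType = StructType(StructField("value", IntegerType) :: Nil)
          // A full scan that produces a single hard-coded row.
          override def buildScan(): RDD[Row] = ctx.sparkContext.parallelize(Seq(Row(1)))
        }
      }
    }

Such a source would be loaded with `spark.read.format(classOf[OneRowSource].getName).load()`.
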
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 4c7dcedafeeae..c8e3e1c191044 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -21,7 +21,7 @@ import java.util.Locale
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.execution.command.DDLUtils
@@ -40,7 +40,7 @@ import org.apache.spark.util.Utils
*
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging {
/**
* Specifies the input data source format.
@@ -158,7 +158,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
"read files of Hive data source directly.")
}
- val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance()
+ val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).
+ getConstructor().newInstance()
// We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
// We can't be sure at this point whether we'll actually want to use V2, since we don't know the
// writer or whether the query is continuous.
@@ -296,6 +297,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* that should be used for parsing.
*
* <li>`dropFieldIfAllNull` (default `false`): whether to ignore column of all null values or
* empty array/struct during schema inference.</li>
+ * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+ * For instance, this is used while parsing dates and timestamps.</li>
*
*
* @since 2.0.0
@@ -372,6 +375,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string
* created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
*
* <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
+ * <li>`locale` (default is `en-US`): sets a locale as language tag in IETF BCP 47 format.
+ * For instance, this is used while parsing dates and timestamps.</li>
+ * <li>`lineSep` (default covers all `\r`, `\r\n` and `\n`): defines the line separator
+ * that should be used for parsing. Maximum length is 1 character.</li>
*
*
* @since 2.0.0
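
The two documentation hunks above add the `locale` and `lineSep` options to the streaming JSON and CSV readers. An illustrative use on a CSV stream (the schema and input path are assumptions, and `spark` is an existing SparkSession):

    import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}

    val csvSchema = StructType(Seq(
      StructField("name", StringType),
      StructField("ts", TimestampType)))

    val input = spark.readStream
      .schema(csvSchema)
      .option("locale", "en-US")   // controls date/timestamp parsing
      .option("lineSep", "\n")     // single-character record separator
      .csv("/tmp/streaming-input")
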
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 4a8c7fdb58ff1..5733258a6b310 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -21,7 +21,7 @@ import java.util.Locale
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
@@ -39,7 +39,7 @@ import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider
*
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
private val df = ds.toDF()
@@ -307,7 +307,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
var options = extraOptions.toMap
- val sink = ds.newInstance() match {
+ val sink = ds.getConstructor().newInstance() match {
case w: StreamingWriteSupportProvider
if !disabledSources.contains(w.getClass.getCanonicalName) =>
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
@@ -365,7 +365,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
*
* @since 2.4.0
*/
- @InterfaceStability.Evolving
+ @Evolving
def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
this.source = "foreachBatch"
if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null")
@@ -386,7 +386,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
*
* @since 2.4.0
*/
- @InterfaceStability.Evolving
+ @Evolving
def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = {
foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId))
}
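
Besides the annotation rename, the writer now instantiates sinks via `getConstructor().newInstance()`, since `Class.newInstance` is deprecated as of Java 9; behavior is unchanged. The `foreachBatch` sink the hunks touch is used roughly like this (the output path is hypothetical and `input` is assumed to be a streaming Dataset[Row]):

    import org.apache.spark.sql.{Dataset, Row}

    // An explicitly typed function avoids overload ambiguity with the Java variant.
    val writeBatch: (Dataset[Row], Long) => Unit = (batch, batchId) =>
      batch.write.mode("append").parquet(s"/tmp/output/batch_id=$batchId")

    val query = input.writeStream
      .foreachBatch(writeBatch)
      .start()
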
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
index e9510c903acae..ab68eba81b843 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.streaming
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.KeyValueGroupedDataset
+import org.apache.spark.annotation.{Evolving, Experimental}
import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
/**
@@ -192,7 +191,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
* @since 2.2.0
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
trait GroupState[S] extends LogicalGroupState[S] {
/** Whether state exists or not. */
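
GroupState itself is unchanged by this patch. A hedged sketch of how it is typically used with `mapGroupsWithState` (the `events` Dataset[(String, Long)] and the SparkSession `spark` are assumptions):

    import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
    import spark.implicits._

    // Maintain a running event count per user key.
    val counts = events
      .groupByKey { case (user, _) => user }
      .mapGroupsWithState(GroupStateTimeout.NoTimeout) {
        (user: String, rows: Iterator[(String, Long)], state: GroupState[Long]) =>
          val total = state.getOption.getOrElse(0L) + rows.size
          state.update(total)
          (user, total)
      }
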
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala
index a033575d3d38f..236bd55ee6212 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala
@@ -23,7 +23,7 @@ import scala.concurrent.duration.Duration
import org.apache.commons.lang3.StringUtils
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
import org.apache.spark.unsafe.types.CalendarInterval
/**
@@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.CalendarInterval
*
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0")
case class ProcessingTime(intervalMs: Long) extends Trigger {
require(intervalMs >= 0, "the interval of trigger should not be negative")
@@ -59,7 +59,7 @@ case class ProcessingTime(intervalMs: Long) extends Trigger {
*
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
@deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0")
object ProcessingTime {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
index f2dfbe42260d7..47ddc88e964e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming
import java.util.UUID
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
import org.apache.spark.sql.SparkSession
/**
@@ -27,7 +27,7 @@ import org.apache.spark.sql.SparkSession
* All these methods are thread-safe.
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
trait StreamingQuery {
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
index 03aeb14de502a..646d6888b2a16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.streaming
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
/**
* Exception that stopped a [[StreamingQuery]]. Use `cause` get the actual exception
@@ -28,7 +28,7 @@ import org.apache.spark.annotation.InterfaceStability
* @param endOffset Ending offset in json of the range of data in exception occurred
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
class StreamingQueryException private[sql](
private val queryDebugString: String,
val message: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
index 6aa82b89ede81..916d6a0365965 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming
import java.util.UUID
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
import org.apache.spark.scheduler.SparkListenerEvent
/**
@@ -28,7 +28,7 @@ import org.apache.spark.scheduler.SparkListenerEvent
*
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
abstract class StreamingQueryListener {
import StreamingQueryListener._
@@ -67,14 +67,14 @@ abstract class StreamingQueryListener {
* Companion object of [[StreamingQueryListener]] that defines the listener events.
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
object StreamingQueryListener {
/**
* Base type of [[StreamingQueryListener]] events
* @since 2.0.0
*/
- @InterfaceStability.Evolving
+ @Evolving
trait Event extends SparkListenerEvent
/**
@@ -84,7 +84,7 @@ object StreamingQueryListener {
* @param name User-specified name of the query, null if not specified.
* @since 2.1.0
*/
- @InterfaceStability.Evolving
+ @Evolving
class QueryStartedEvent private[sql](
val id: UUID,
val runId: UUID,
@@ -95,7 +95,7 @@ object StreamingQueryListener {
* @param progress The query progress updates.
* @since 2.1.0
*/
- @InterfaceStability.Evolving
+ @Evolving
class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event
/**
@@ -107,7 +107,7 @@ object StreamingQueryListener {
* with an exception. Otherwise, it will be `None`.
* @since 2.1.0
*/
- @InterfaceStability.Evolving
+ @Evolving
class QueryTerminatedEvent private[sql](
val id: UUID,
val runId: UUID,
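
The listener and event classes above change annotations only. A minimal listener built on them, registered through the streams manager (`spark` is an assumed SparkSession):

    import org.apache.spark.sql.streaming.StreamingQueryListener
    import org.apache.spark.sql.streaming.StreamingQueryListener._

    class LoggingListener extends StreamingQueryListener {
      override def onQueryStarted(event: QueryStartedEvent): Unit =
        println(s"query ${event.id} started")
      override def onQueryProgress(event: QueryProgressEvent): Unit =
        println(event.progress.json)   // one JSON progress report per trigger
      override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
        println(s"query ${event.id} terminated")
    }

    spark.streams.addListener(new LoggingListener)
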
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index cd52d991d55c9..d9fe1a992a093 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
@@ -42,7 +42,7 @@ import org.apache.spark.util.{Clock, SystemClock, Utils}
*
* @since 2.0.0
*/
-@InterfaceStability.Evolving
+@Evolving
class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging {
private[sql] val stateStoreCoordinator =
@@ -311,7 +311,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
outputMode: OutputMode,
useTempCheckpointLocation: Boolean = false,
recoverFromCheckpointLocation: Boolean = true,
- trigger: Trigger = ProcessingTime(0),
+ trigger: Trigger = Trigger.ProcessingTime(0),
triggerClock: Clock = new SystemClock()): StreamingQuery = {
val query = createQuery(
userSpecifiedName,
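
The only functional change in this file is the default trigger: the deprecated `ProcessingTime(0)` case class gives way to the `Trigger.ProcessingTime(0)` factory. The same preference applies in user code, for example with an assumed streaming DataFrame `input`:

    import org.apache.spark.sql.streaming.Trigger

    val query = input.writeStream
      .format("console")
      .trigger(Trigger.ProcessingTime("5 seconds"))
      .start()
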
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
index a0c9bcc8929eb..9dc62b7aac891 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
@@ -22,7 +22,7 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
/**
* Reports information about the instantaneous status of a streaming query.
@@ -34,7 +34,7 @@ import org.apache.spark.annotation.InterfaceStability
*
* @since 2.1.0
*/
-@InterfaceStability.Evolving
+@Evolving
class StreamingQueryStatus protected[sql](
val message: String,
val isDataAvailable: Boolean,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index f2173aa1e59c2..3cd6700efef5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -29,12 +29,12 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.annotation.Evolving
/**
* Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger.
*/
-@InterfaceStability.Evolving
+@Evolving
class StateOperatorProgress private[sql](
val numRowsTotal: Long,
val numRowsUpdated: Long,
@@ -94,7 +94,7 @@ class StateOperatorProgress private[sql](
* @param sources detailed statistics on data being read from each of the streaming sources.
* @since 2.1.0
*/
-@InterfaceStability.Evolving
+@Evolving
class StreamingQueryProgress private[sql](
val id: UUID,
val runId: UUID,
@@ -165,7 +165,7 @@ class StreamingQueryProgress private[sql](
* Spark.
* @since 2.1.0
*/
-@InterfaceStability.Evolving
+@Evolving
class SourceProgress protected[sql](
val description: String,
val startOffset: String,
@@ -209,7 +209,7 @@ class SourceProgress protected[sql](
* @param description Description of the source corresponding to this status.
* @since 2.1.0
*/
-@InterfaceStability.Evolving
+@Evolving
class SinkProgress protected[sql](
val description: String) extends Serializable {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 8bab7e1c58762..7beac16599de5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -45,7 +45,7 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
override def sqlType: DataType = ArrayType(DoubleType, false)
- override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
+ override def pyUDT: String = "pyspark.testing.sqlutils.ExamplePointUDT"
override def serialize(p: ExamplePoint): GenericArrayData = {
val output = new Array[Any](2)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 1310fdfa1356b..77ae047705de0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.util
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.SparkSession
@@ -36,7 +36,7 @@ import org.apache.spark.util.{ListenerBus, Utils}
* multiple different threads.
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
trait QueryExecutionListener {
/**
@@ -73,7 +73,7 @@ trait QueryExecutionListener {
* Manager for [[QueryExecutionListener]]. See `org.apache.spark.sql.SQLContext.listenerManager`.
*/
@Experimental
-@InterfaceStability.Evolving
+@Evolving
// The `session` is used to indicate which session carries this listener manager, and we only
// catch SQL executions which are launched by the same session.
// The `loadExtensions` flag is used to indicate whether we should load the pre-defined,
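
QueryExecutionListener keeps its two callbacks; only the annotation import changes. A sketch of a listener that records the names of successful actions (`spark` is an assumed SparkSession):

    import scala.collection.mutable
    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.util.QueryExecutionListener

    class RecordingListener extends QueryExecutionListener {
      val succeeded = mutable.Buffer[String]()
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        succeeded += funcName
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
    }

    spark.listenerManager.register(new RecordingListener)
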
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index 7f975a647c241..8f35abeb579b5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -143,11 +143,16 @@ public void setIntervals(List intervals) {
this.intervals = intervals;
}
+ @Override
+ public int hashCode() {
+ return id ^ Objects.hashCode(intervals);
+ }
+
@Override
public boolean equals(Object obj) {
if (!(obj instanceof ArrayRecord)) return false;
ArrayRecord other = (ArrayRecord) obj;
- return (other.id == this.id) && other.intervals.equals(this.intervals);
+ return (other.id == this.id) && Objects.equals(other.intervals, this.intervals);
}
@Override
@@ -184,6 +189,11 @@ public void setIntervals(Map intervals) {
this.intervals = intervals;
}
+ @Override
+ public int hashCode() {
+ return id ^ Objects.hashCode(intervals);
+ }
+
@Override
public boolean equals(Object obj) {
if (!(obj instanceof MapRecord)) return false;
@@ -225,6 +235,11 @@ public void setEndTime(long endTime) {
this.endTime = endTime;
}
+ @Override
+ public int hashCode() {
+ return Long.hashCode(startTime) ^ Long.hashCode(endTime);
+ }
+
@Override
public boolean equals(Object obj) {
if (!(obj instanceof Interval)) return false;
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java
index 3ab4db2a035d3..ca78d6489ef5c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java
@@ -67,20 +67,20 @@ public void setUp() {
public void constructSimpleRow() {
Row simpleRow = RowFactory.create(
byteValue, // ByteType
- new Byte(byteValue),
+ Byte.valueOf(byteValue),
shortValue, // ShortType
- new Short(shortValue),
+ Short.valueOf(shortValue),
intValue, // IntegerType
- new Integer(intValue),
+ Integer.valueOf(intValue),
longValue, // LongType
- new Long(longValue),
+ Long.valueOf(longValue),
floatValue, // FloatType
- new Float(floatValue),
+ Float.valueOf(floatValue),
doubleValue, // DoubleType
- new Double(doubleValue),
+ Double.valueOf(doubleValue),
decimalValue, // DecimalType
booleanValue, // BooleanType
- new Boolean(booleanValue),
+ Boolean.valueOf(booleanValue),
stringValue, // StringType
binaryValue, // BinaryType
dateValue, // DateType
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java
index b90224f2ae397..5955eabe496df 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java
@@ -25,6 +25,6 @@
public class JavaStringLength implements UDF1 {
@Override
public Integer call(String str) throws Exception {
- return new Integer(str.length());
+ return Integer.valueOf(str.length());
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java
new file mode 100644
index 0000000000000..438f489a3eea7
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaRangeInputPartition.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
+
+class JavaRangeInputPartition implements InputPartition {
+ int start;
+ int end;
+
+ JavaRangeInputPartition(int start, int end) {
+ this.start = start;
+ this.end = end;
+ }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java
index 685f9b9747e85..ced51dde6997b 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java
@@ -88,12 +88,3 @@ public void close() throws IOException {
}
}
-class JavaRangeInputPartition implements InputPartition {
- int start;
- int end;
-
- JavaRangeInputPartition(int start, int end) {
- this.start = start;
- this.end = end;
- }
-}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
index 69da67fc66fc0..60895020fcc83 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql
@@ -13,7 +13,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
map('a', 'b'), map('c', 'd'),
map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')),
map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)),
- map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)),
map('a', 1), map('c', 2),
map(1, 'a'), map(2, 'c')
) AS various_maps (
@@ -31,7 +30,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
string_map1, string_map2,
array_map1, array_map2,
struct_map1, struct_map2,
- map_map1, map_map2,
string_int_map1, string_int_map2,
int_string_map1, int_string_map2
);
@@ -51,7 +49,6 @@ SELECT
map_concat(string_map1, string_map2) string_map,
map_concat(array_map1, array_map2) array_map,
map_concat(struct_map1, struct_map2) struct_map,
- map_concat(map_map1, map_map2) map_map,
map_concat(string_int_map1, string_int_map2) string_int_map,
map_concat(int_string_map1, int_string_map2) int_string_map
FROM various_maps;
@@ -71,7 +68,7 @@ FROM various_maps;
-- Concatenate map of incompatible types 1
SELECT
- map_concat(tinyint_map1, map_map2) tm_map
+ map_concat(tinyint_map1, array_map1) tm_map
FROM various_maps;
-- Concatenate map of incompatible types 2
@@ -86,10 +83,10 @@ FROM various_maps;
-- Concatenate map of incompatible types 4
SELECT
- map_concat(map_map1, array_map2) ma_map
+ map_concat(struct_map1, array_map2) ma_map
FROM various_maps;
-- Concatenate map of incompatible types 5
SELECT
- map_concat(map_map1, struct_map2) ms_map
+ map_concat(int_map1, array_map2) ms_map
FROM various_maps;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql
index cda4db4b449fe..faab4c61c8640 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/window.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql
@@ -109,3 +109,9 @@ last_value(false, false) OVER w AS last_value_contain_null
FROM testData
WINDOW w AS ()
ORDER BY cate, val;
+
+-- parentheses around window reference
+SELECT cate, sum(val) OVER (w)
+FROM testData
+WHERE val is not null
+WINDOW w AS (PARTITION BY cate ORDER BY val);
diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
index 270eee8680225..17dd317f63b70 100644
--- a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
@@ -93,7 +93,7 @@ Partition Values [ds=2017-08-01, hr=10]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
Created Time [not included in comparison]
Last Access [not included in comparison]
-Partition Statistics 1264 bytes, 3 rows
+Partition Statistics [not included in comparison] bytes, 3 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -128,7 +128,7 @@ Partition Values [ds=2017-08-01, hr=10]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
Created Time [not included in comparison]
Last Access [not included in comparison]
-Partition Statistics 1264 bytes, 3 rows
+Partition Statistics [not included in comparison] bytes, 3 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -155,7 +155,7 @@ Partition Values [ds=2017-08-01, hr=11]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11
Created Time [not included in comparison]
Last Access [not included in comparison]
-Partition Statistics 1278 bytes, 4 rows
+Partition Statistics [not included in comparison] bytes, 4 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -190,7 +190,7 @@ Partition Values [ds=2017-08-01, hr=10]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
Created Time [not included in comparison]
Last Access [not included in comparison]
-Partition Statistics 1264 bytes, 3 rows
+Partition Statistics [not included in comparison] bytes, 3 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -217,7 +217,7 @@ Partition Values [ds=2017-08-01, hr=11]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11
Created Time [not included in comparison]
Last Access [not included in comparison]
-Partition Statistics 1278 bytes, 4 rows
+Partition Statistics [not included in comparison] bytes, 4 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
@@ -244,7 +244,7 @@ Partition Values [ds=2017-09-01, hr=5]
Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5
Created Time [not included in comparison]
Last Access [not included in comparison]
-Partition Statistics 1250 bytes, 2 rows
+Partition Statistics [not included in comparison] bytes, 2 rows
# Storage Information
Location [not included in comparison]sql/core/spark-warehouse/t
diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
index fd1d0db9e3f78..570b281353f3d 100644
--- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
@@ -201,7 +201,7 @@ struct
-- !query 24 output
== Physical Plan ==
*Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
-- !query 25
@@ -211,7 +211,7 @@ struct
-- !query 25 output
== Physical Plan ==
*Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
-- !query 26
@@ -221,7 +221,7 @@ struct
-- !query 26 output
== Physical Plan ==
*Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
-- !query 27
@@ -231,7 +231,7 @@ struct
-- !query 27 output
== Physical Plan ==
*Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
-- !query 28
@@ -241,7 +241,7 @@ struct
-- !query 28 output
== Physical Plan ==
*Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
-- !query 29
@@ -251,7 +251,7 @@ struct
-- !query 29 output
== Physical Plan ==
*Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x]
-+- Scan OneRowRelation[]
++- *Scan OneRowRelation[]
-- !query 30
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
index efc88e47209a6..79e00860e4c05 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out
@@ -18,7 +18,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
map('a', 'b'), map('c', 'd'),
map(array('a', 'b'), array('c', 'd')), map(array('e'), array('f')),
map(struct('a', 1), struct('b', 2)), map(struct('c', 3), struct('d', 4)),
- map(map('a', 1), map('b', 2)), map(map('c', 3), map('d', 4)),
map('a', 1), map('c', 2),
map(1, 'a'), map(2, 'c')
) AS various_maps (
@@ -36,7 +35,6 @@ CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
string_map1, string_map2,
array_map1, array_map2,
struct_map1, struct_map2,
- map_map1, map_map2,
string_int_map1, string_int_map2,
int_string_map1, int_string_map2
)
@@ -61,14 +59,13 @@ SELECT
map_concat(string_map1, string_map2) string_map,
map_concat(array_map1, array_map2) array_map,
map_concat(struct_map1, struct_map2) struct_map,
- map_concat(map_map1, map_map2) map_map,
map_concat(string_int_map1, string_int_map2) string_int_map,
map_concat(int_string_map1, int_string_map2) int_string_map
FROM various_maps
-- !query 1 schema
-struct,tinyint_map:map,smallint_map:map,int_map:map,bigint_map:map,decimal_map:map,float_map:map,double_map:map,date_map:map,timestamp_map:map,string_map:map,array_map:map,array>,struct_map:map,struct>,map_map:map