diff --git a/.asf.yaml b/.asf.yaml
index 16cdf8bfed322..ae5e99cf230d8 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -31,3 +31,8 @@ github:
merge: false
squash: true
rebase: true
+
+notifications:
+ pullrequests: reviews@spark.apache.org
+ issues: reviews@spark.apache.org
+ commits: commits@spark.apache.org
diff --git a/.github/labeler.yml b/.github/labeler.yml
index bd61902925e33..afaeeecda51a2 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -84,12 +84,12 @@ SPARK SHELL:
- "repl/**/*"
- "bin/spark-shell*"
SQL:
-#- any: ["**/sql/**/*", "!python/pyspark/sql/avro/**/*", "!python/pyspark/sql/streaming.py", "!python/pyspark/sql/tests/test_streaming.py"]
+#- any: ["**/sql/**/*", "!python/pyspark/sql/avro/**/*", "!python/pyspark/sql/streaming/**/*", "!python/pyspark/sql/tests/streaming/test_streaming.py"]
- "**/sql/**/*"
- "common/unsafe/**/*"
#- "!python/pyspark/sql/avro/**/*"
- #- "!python/pyspark/sql/streaming.py"
- #- "!python/pyspark/sql/tests/test_streaming.py"
+ #- "!python/pyspark/sql/streaming/**/*"
+ #- "!python/pyspark/sql/tests/streaming/test_streaming.py"
- "bin/spark-sql*"
- "bin/beeline*"
- "sbin/*thriftserver*.sh"
@@ -103,7 +103,7 @@ SQL:
- "**/*schema.R"
- "**/*types.R"
AVRO:
- - "external/avro/**/*"
+ - "connector/avro/**/*"
- "python/pyspark/sql/avro/**/*"
DSTREAM:
- "streaming/**/*"
@@ -123,13 +123,15 @@ MLLIB:
- "python/pyspark/mllib/**/*"
STRUCTURED STREAMING:
- "**/sql/**/streaming/**/*"
- - "external/kafka-0-10-sql/**/*"
- - "python/pyspark/sql/streaming.py"
- - "python/pyspark/sql/tests/test_streaming.py"
+ - "connector/kafka-0-10-sql/**/*"
+ - "python/pyspark/sql/streaming/**/*"
+ - "python/pyspark/sql/tests/streaming/test_streaming.py"
- "**/*streaming.R"
PYTHON:
- "bin/pyspark*"
- "**/python/**/*"
+PANDAS API ON SPARK:
+ - "python/pyspark/pandas/**/*"
R:
- "**/r/**/*"
- "**/R/**/*"
@@ -149,4 +151,10 @@ WEB UI:
- "**/*UI.scala"
DEPLOY:
- "sbin/**/*"
-
+CONNECT:
+ - "connector/connect/**/*"
+ - "**/sql/sparkconnect/**/*"
+ - "python/pyspark/sql/**/connect/**/*"
+PROTOBUF:
+ - "connector/protobuf/**/*"
+ - "python/pyspark/sql/protobuf/**/*"
diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 91e168210fb30..8671cff054bb8 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -30,6 +30,10 @@ on:
description: 'JDK version: 8, 11 or 17'
required: true
default: '8'
+ scala:
+ description: 'Scala version: 2.12 or 2.13'
+ required: true
+ default: '2.12'
failfast:
description: 'Failfast: true or false'
required: true
@@ -50,11 +54,69 @@ jobs:
steps:
- name: Generate matrix
id: set-matrix
- run: echo "::set-output name=matrix::["`seq -s, 1 $SPARK_BENCHMARK_NUM_SPLITS`"]"
+ run: echo "matrix=["`seq -s, 1 $SPARK_BENCHMARK_NUM_SPLITS`"]" >> $GITHUB_OUTPUT
+
+  # Any TPC-DS related updates to this job need to be applied to the tpcds-1g job in build_and_test.yml as well
+ tpcds-1g-gen:
+ name: "Generate an input dataset for TPCDSQueryBenchmark with SF=1"
+ if: contains(github.event.inputs.class, 'TPCDSQueryBenchmark') || contains(github.event.inputs.class, '*')
+ runs-on: ubuntu-20.04
+ env:
+ SPARK_LOCAL_IP: localhost
+ steps:
+ - name: Checkout Spark repository
+ uses: actions/checkout@v3
+ # In order to get diff files
+ with:
+ fetch-depth: 0
+ - name: Cache Scala, SBT and Maven
+ uses: actions/cache@v3
+ with:
+ path: |
+ build/apache-maven-*
+ build/scala-*
+ build/*.jar
+ ~/.sbt
+ key: build-${{ hashFiles('**/pom.xml', 'project/build.properties', 'build/mvn', 'build/sbt', 'build/sbt-launch-lib.bash', 'build/spark-build-info') }}
+ restore-keys: |
+ build-
+ - name: Cache Coursier local repository
+ uses: actions/cache@v3
+ with:
+ path: ~/.cache/coursier
+ key: benchmark-coursier-${{ github.event.inputs.jdk }}-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
+ restore-keys: |
+ benchmark-coursier-${{ github.event.inputs.jdk }}
+ - name: Cache TPC-DS generated data
+ id: cache-tpcds-sf-1
+ uses: actions/cache@v3
+ with:
+ path: ./tpcds-sf-1
+ key: tpcds-${{ hashFiles('.github/workflows/benchmark.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }}
+ - name: Checkout tpcds-kit repository
+ if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true'
+ uses: actions/checkout@v3
+ with:
+ repository: databricks/tpcds-kit
+ ref: 2a5078a782192ddb6efbcead8de9973d6ab4f069
+ path: ./tpcds-kit
+ - name: Build tpcds-kit
+ if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true'
+ run: cd tpcds-kit/tools && make OS=LINUX
+ - name: Install Java ${{ github.event.inputs.jdk }}
+ if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true'
+ uses: actions/setup-java@v3
+ with:
+ distribution: temurin
+ java-version: ${{ github.event.inputs.jdk }}
+ - name: Generate TPC-DS (SF=1) table data
+ if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true'
+ run: build/sbt "sql/Test/runMain org.apache.spark.sql.GenTPCDSData --dsdgenDir `pwd`/tpcds-kit/tools --location `pwd`/tpcds-sf-1 --scaleFactor 1 --numPartitions 1 --overwrite"
benchmark:
- name: "Run benchmarks: ${{ github.event.inputs.class }} (JDK ${{ github.event.inputs.jdk }}, ${{ matrix.split }} out of ${{ github.event.inputs.num-splits }} splits)"
- needs: matrix-gen
+ name: "Run benchmarks: ${{ github.event.inputs.class }} (JDK ${{ github.event.inputs.jdk }}, Scala ${{ github.event.inputs.scala }}, ${{ matrix.split }} out of ${{ github.event.inputs.num-splits }} splits)"
+ if: always()
+ needs: [matrix-gen, tpcds-1g-gen]
# Ubuntu 20.04 is the latest LTS. The next LTS is 22.04.
runs-on: ubuntu-20.04
strategy:
@@ -69,14 +131,15 @@ jobs:
SPARK_LOCAL_IP: localhost
# To prevent spark.test.home not being set. See more detail in SPARK-36007.
SPARK_HOME: ${{ github.workspace }}
+ SPARK_TPCDS_DATA: ${{ github.workspace }}/tpcds-sf-1
steps:
- name: Checkout Spark repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
# In order to get diff files
with:
fetch-depth: 0
- name: Cache Scala, SBT and Maven
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: |
build/apache-maven-*
@@ -87,19 +150,28 @@ jobs:
restore-keys: |
build-
- name: Cache Coursier local repository
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/coursier
key: benchmark-coursier-${{ github.event.inputs.jdk }}-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
restore-keys: |
benchmark-coursier-${{ github.event.inputs.jdk }}
- name: Install Java ${{ github.event.inputs.jdk }}
- uses: actions/setup-java@v1
+ uses: actions/setup-java@v3
with:
+ distribution: temurin
java-version: ${{ github.event.inputs.jdk }}
+ - name: Cache TPC-DS generated data
+ if: contains(github.event.inputs.class, 'TPCDSQueryBenchmark') || contains(github.event.inputs.class, '*')
+ id: cache-tpcds-sf-1
+ uses: actions/cache@v3
+ with:
+ path: ./tpcds-sf-1
+ key: tpcds-${{ hashFiles('.github/workflows/benchmark.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }}
- name: Run benchmarks
run: |
- ./build/sbt -Pyarn -Pmesos -Pkubernetes -Phive -Phive-thriftserver -Phadoop-cloud -Pkinesis-asl -Pspark-ganglia-lgpl test:package
+ dev/change-scala-version.sh ${{ github.event.inputs.scala }}
+ ./build/sbt -Pscala-${{ github.event.inputs.scala }} -Pyarn -Pmesos -Pkubernetes -Phive -Phive-thriftserver -Phadoop-cloud -Pkinesis-asl -Pspark-ganglia-lgpl Test/package
# Make less noisy
cp conf/log4j2.properties.template conf/log4j2.properties
sed -i 's/rootLogger.level = info/rootLogger.level = warn/g' conf/log4j2.properties
@@ -109,13 +181,15 @@ jobs:
--jars "`find . -name '*-SNAPSHOT-tests.jar' -o -name '*avro*-SNAPSHOT.jar' | paste -sd ',' -`" \
"`find . -name 'spark-core*-SNAPSHOT-tests.jar'`" \
"${{ github.event.inputs.class }}"
+          # Revert to the default Scala version to avoid an unnecessary git diff
+ dev/change-scala-version.sh 2.12
# To keep the directory structure and file permissions, tar them
# See also https://github.com/actions/upload-artifact#maintaining-file-permissions-and-case-sensitive-files
echo "Preparing the benchmark results:"
- tar -cvf benchmark-results-${{ github.event.inputs.jdk }}.tar `git diff --name-only` `git ls-files --others --exclude-standard`
+ tar -cvf benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}.tar `git diff --name-only` `git ls-files --others --exclude=tpcds-sf-1 --exclude-standard`
- name: Upload benchmark results
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
- name: benchmark-results-${{ github.event.inputs.jdk }}-${{ matrix.split }}
- path: benchmark-results-${{ github.event.inputs.jdk }}.tar
+ name: benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}-${{ matrix.split }}
+ path: benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}.tar
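Note on the output handling change in the matrix-gen step above: GitHub Actions deprecated the `::set-output` workflow command in favor of appending key=value lines to the file referenced by $GITHUB_OUTPUT, which is what the rewritten `run` line does. A minimal sketch of the new pattern, where the consumer step is hypothetical and only illustrates how the value is read back:

    steps:
      - name: Generate matrix
        id: set-matrix
        # Append a key=value line to the $GITHUB_OUTPUT file
        run: echo "matrix=[1,2,3]" >> "$GITHUB_OUTPUT"
      - name: Use the matrix (hypothetical consumer step)
        run: echo "splits: ${{ steps.set-matrix.outputs.matrix }}"
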
diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index a392f940df99d..29a9a58de08a8 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -20,74 +20,35 @@
name: Build and test
on:
- push:
- branches:
- - '**'
workflow_call:
inputs:
- ansi_enabled:
+ java:
required: false
- type: boolean
- default: false
-
+ type: string
+ default: 8
+ branch:
+ description: Branch to run the build against
+ required: false
+ type: string
+ default: branch-3.4
+ hadoop:
+        description: Hadoop version to run with. The HADOOP_PROFILE environment variable should accept it.
+ required: false
+ type: string
+ default: hadoop3
+ envs:
+ description: Additional environment variables to set when running the tests. Should be in JSON format.
+ required: false
+ type: string
+ default: '{}'
+ jobs:
+ description: >-
+          Jobs to run, in JSON format. The values should match the job keys defined
+          in this file, e.g., build. See the precondition job below.
+ required: false
+ type: string
+ default: ''
jobs:
- configure-jobs:
- name: Configure jobs
- runs-on: ubuntu-20.04
- outputs:
- java: ${{ steps.set-outputs.outputs.java }}
- branch: ${{ steps.set-outputs.outputs.branch }}
- hadoop: ${{ steps.set-outputs.outputs.hadoop }}
- type: ${{ steps.set-outputs.outputs.type }}
- envs: ${{ steps.set-outputs.outputs.envs }}
- steps:
- - name: Configure branch and additional environment variables
- id: set-outputs
- run: |
- if [ "${{ github.event.schedule }}" = "0 1 * * *" ]; then
- echo '::set-output name=java::8'
- echo '::set-output name=branch::master'
- echo '::set-output name=type::scheduled'
- echo '::set-output name=envs::{}'
- echo '::set-output name=hadoop::hadoop2'
- elif [ "${{ github.event.schedule }}" = "0 4 * * *" ]; then
- echo '::set-output name=java::8'
- echo '::set-output name=branch::master'
- echo '::set-output name=type::scheduled'
- echo '::set-output name=envs::{"SCALA_PROFILE": "scala2.13"}'
- echo '::set-output name=hadoop::hadoop3'
- elif [ "${{ github.event.schedule }}" = "0 7 * * *" ]; then
- echo '::set-output name=java::8'
- echo '::set-output name=branch::branch-3.2'
- echo '::set-output name=type::scheduled'
- echo '::set-output name=envs::{"SCALA_PROFILE": "scala2.13"}'
- echo '::set-output name=hadoop::hadoop3.2'
- elif [ "${{ github.event.schedule }}" = "0 10 * * *" ]; then
- echo '::set-output name=java::8'
- echo '::set-output name=branch::master'
- echo '::set-output name=type::pyspark-coverage-scheduled'
- echo '::set-output name=envs::{"PYSPARK_CODECOV": "true"}'
- echo '::set-output name=hadoop::hadoop3'
- elif [ "${{ github.event.schedule }}" = "0 13 * * *" ]; then
- echo '::set-output name=java::11'
- echo '::set-output name=branch::master'
- echo '::set-output name=type::scheduled'
- echo '::set-output name=envs::{"SKIP_MIMA": "true", "SKIP_UNIDOC": "true"}'
- echo '::set-output name=hadoop::hadoop3'
- elif [ "${{ github.event.schedule }}" = "0 16 * * *" ]; then
- echo '::set-output name=java::17'
- echo '::set-output name=branch::master'
- echo '::set-output name=type::scheduled'
- echo '::set-output name=envs::{"SKIP_MIMA": "true", "SKIP_UNIDOC": "true"}'
- echo '::set-output name=hadoop::hadoop3'
- else
- echo '::set-output name=java::8'
- echo '::set-output name=branch::branch-3.3' # Default branch to run on. CHANGE here when a branch is cut out.
- echo '::set-output name=type::regular'
- echo '::set-output name=envs::{"SPARK_ANSI_SQL_MODE": "${{ inputs.ansi_enabled }}"}'
- echo '::set-output name=hadoop::hadoop3'
- fi
-
precondition:
name: Check changes
runs-on: ubuntu-20.04
@@ -95,50 +56,86 @@ jobs:
GITHUB_PREV_SHA: ${{ github.event.before }}
outputs:
required: ${{ steps.set-outputs.outputs.required }}
+ image_url: >-
+ ${{
+ (inputs.branch == 'branch-3.4' && steps.infra-image-outputs.outputs.image_url)
+ || 'dongjoon/apache-spark-github-action-image:20220207'
+ }}
steps:
- name: Checkout Spark repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
with:
fetch-depth: 0
repository: apache/spark
- ref: branch-3.3
+ ref: ${{ inputs.branch }}
- name: Sync the current branch with the latest in Apache Spark
if: github.repository != 'apache/spark'
run: |
echo "APACHE_SPARK_REF=$(git rev-parse HEAD)" >> $GITHUB_ENV
git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
- git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit"
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
- name: Check all modules
id: set-outputs
run: |
- build=`./dev/is-changed.py -m avro,build,catalyst,core,docker-integration-tests,examples,graphx,hadoop-cloud,hive,hive-thriftserver,kubernetes,kvstore,launcher,mesos,mllib,mllib-local,network-common,network-shuffle,pyspark-core,pyspark-ml,pyspark-mllib,pyspark-pandas,pyspark-pandas-slow,pyspark-resource,pyspark-sql,pyspark-streaming,repl,sketch,spark-ganglia-lgpl,sparkr,sql,sql-kafka-0-10,streaming,streaming-kafka-0-10,streaming-kinesis-asl,tags,unsafe,yarn`
- pyspark=`./dev/is-changed.py -m avro,build,catalyst,core,graphx,hive,kvstore,launcher,mllib,mllib-local,network-common,network-shuffle,pyspark-core,pyspark-ml,pyspark-mllib,pyspark-pandas,pyspark-pandas-slow,pyspark-resource,pyspark-sql,pyspark-streaming,repl,sketch,sql,tags,unsafe`
- sparkr=`./dev/is-changed.py -m avro,build,catalyst,core,hive,kvstore,launcher,mllib,mllib-local,network-common,network-shuffle,repl,sketch,sparkr,sql,tags,unsafe`
- tpcds=`./dev/is-changed.py -m build,catalyst,core,hive,kvstore,launcher,network-common,network-shuffle,repl,sketch,sql,tags,unsafe`
- docker=`./dev/is-changed.py -m build,catalyst,core,docker-integration-tests,hive,kvstore,launcher,network-common,network-shuffle,repl,sketch,sql,tags,unsafe`
- echo "{\"build\": \"$build\", \"pyspark\": \"$pyspark\", \"sparkr\": \"$sparkr\", \"tpcds\": \"$tpcds\", \"docker\": \"$docker\"}" > required.json
- cat required.json
- echo "::set-output name=required::$(cat required.json)"
+ if [ -z "${{ inputs.jobs }}" ]; then
+            # is-changed.py is missing in branch-3.2, which this workflow might run against in a scheduled build; see also SPARK-39517
+ pyspark=true; sparkr=true; tpcds=true; docker=true;
+ if [ -f "./dev/is-changed.py" ]; then
+ pyspark_modules=`cd dev && python -c "import sparktestsupport.modules as m; print(','.join(m.name for m in m.all_modules if m.name.startswith('pyspark')))"`
+ pyspark=`./dev/is-changed.py -m $pyspark_modules`
+ sparkr=`./dev/is-changed.py -m sparkr`
+ tpcds=`./dev/is-changed.py -m sql`
+ docker=`./dev/is-changed.py -m docker-integration-tests`
+ fi
+ # 'build', 'scala-213', and 'java-11-17' are always true for now.
+            # It does not save significant time, and most PRs trigger the build.
+ precondition="
+ {
+ \"build\": \"true\",
+ \"pyspark\": \"$pyspark\",
+ \"sparkr\": \"$sparkr\",
+ \"tpcds-1g\": \"$tpcds\",
+ \"docker-integration-tests\": \"$docker\",
+ \"scala-213\": \"true\",
+ \"java-11-17\": \"true\",
+ \"lint\" : \"true\",
+ \"k8s-integration-tests\" : \"true\",
+ }"
+ echo $precondition # For debugging
+ # Remove `\n` to avoid "Invalid format" error
+ precondition="${precondition//$'\n'/}}"
+ echo "required=$precondition" >> $GITHUB_OUTPUT
+ else
+ # This is usually set by scheduled jobs.
+ precondition='${{ inputs.jobs }}'
+ echo $precondition # For debugging
+ precondition="${precondition//$'\n'/}"
+ echo "required=$precondition" >> $GITHUB_OUTPUT
+ fi
+ - name: Generate infra image URL
+ id: infra-image-outputs
+ run: |
+ # Convert to lowercase to meet Docker repo name requirement
+ REPO_OWNER=$(echo "${{ github.repository_owner }}" | tr '[:upper:]' '[:lower:]')
+ IMG_NAME="apache-spark-ci-image:${{ inputs.branch }}-${{ github.run_id }}"
+ IMG_URL="ghcr.io/$REPO_OWNER/$IMG_NAME"
+ echo "image_url=$IMG_URL" >> $GITHUB_OUTPUT
# Build: build Spark and run the tests for specified modules.
build:
- name: "Build modules (${{ format('{0}, {1} job', needs.configure-jobs.outputs.branch, needs.configure-jobs.outputs.type) }}): ${{ matrix.modules }} ${{ matrix.comment }} (JDK ${{ matrix.java }}, ${{ matrix.hadoop }}, ${{ matrix.hive }})"
- needs: [configure-jobs, precondition]
- # Run scheduled jobs for Apache Spark only
- # Run regular jobs for commit in both Apache Spark and forked repository
- if: >-
- (github.repository == 'apache/spark' && needs.configure-jobs.outputs.type == 'scheduled')
- || (needs.configure-jobs.outputs.type == 'regular' && fromJson(needs.precondition.outputs.required).build == 'true')
+ name: "Build modules: ${{ matrix.modules }} ${{ matrix.comment }}"
+ needs: precondition
+ if: fromJson(needs.precondition.outputs.required).build == 'true'
# Ubuntu 20.04 is the latest LTS. The next LTS is 22.04.
runs-on: ubuntu-20.04
strategy:
fail-fast: false
matrix:
java:
- - ${{ needs.configure-jobs.outputs.java }}
+ - ${{ inputs.java }}
hadoop:
- - ${{ needs.configure-jobs.outputs.hadoop }}
+ - ${{ inputs.hadoop }}
hive:
- hive2.3
# TODO(SPARK-32246): We don't test 'streaming-kinesis-asl' for now.
@@ -154,7 +151,8 @@ jobs:
- >-
streaming, sql-kafka-0-10, streaming-kafka-0-10,
mllib-local, mllib,
- yarn, mesos, kubernetes, hadoop-cloud, spark-ganglia-lgpl
+ yarn, mesos, kubernetes, hadoop-cloud, spark-ganglia-lgpl,
+ connect, protobuf
# Here, we split Hive and SQL tests into some of slow ones and the rest of them.
included-tags: [""]
excluded-tags: [""]
@@ -162,27 +160,27 @@ jobs:
include:
# Hive tests
- modules: hive
- java: ${{ needs.configure-jobs.outputs.java }}
- hadoop: ${{ needs.configure-jobs.outputs.hadoop }}
+ java: ${{ inputs.java }}
+ hadoop: ${{ inputs.hadoop }}
hive: hive2.3
included-tags: org.apache.spark.tags.SlowHiveTest
comment: "- slow tests"
- modules: hive
- java: ${{ needs.configure-jobs.outputs.java }}
- hadoop: ${{ needs.configure-jobs.outputs.hadoop }}
+ java: ${{ inputs.java }}
+ hadoop: ${{ inputs.hadoop }}
hive: hive2.3
excluded-tags: org.apache.spark.tags.SlowHiveTest
comment: "- other tests"
# SQL tests
- modules: sql
- java: ${{ needs.configure-jobs.outputs.java }}
- hadoop: ${{ needs.configure-jobs.outputs.hadoop }}
+ java: ${{ inputs.java }}
+ hadoop: ${{ inputs.hadoop }}
hive: hive2.3
included-tags: org.apache.spark.tags.ExtendedSQLTest
comment: "- slow tests"
- modules: sql
- java: ${{ needs.configure-jobs.outputs.java }}
- hadoop: ${{ needs.configure-jobs.outputs.hadoop }}
+ java: ${{ inputs.java }}
+ hadoop: ${{ inputs.hadoop }}
hive: hive2.3
excluded-tags: org.apache.spark.tags.ExtendedSQLTest
comment: "- other tests"
@@ -196,22 +194,22 @@ jobs:
SPARK_LOCAL_IP: localhost
steps:
- name: Checkout Spark repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
# In order to fetch changed files
with:
fetch-depth: 0
repository: apache/spark
- ref: ${{ needs.configure-jobs.outputs.branch }}
+ ref: ${{ inputs.branch }}
- name: Sync the current branch with the latest in Apache Spark
if: github.repository != 'apache/spark'
run: |
echo "APACHE_SPARK_REF=$(git rev-parse HEAD)" >> $GITHUB_ENV
git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
- git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit"
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
# Cache local repositories. Note that GitHub Actions cache has a 2G limit.
- name: Cache Scala, SBT and Maven
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: |
build/apache-maven-*
@@ -222,18 +220,19 @@ jobs:
restore-keys: |
build-
- name: Cache Coursier local repository
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/coursier
key: ${{ matrix.java }}-${{ matrix.hadoop }}-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
restore-keys: |
${{ matrix.java }}-${{ matrix.hadoop }}-coursier-
- name: Install Java ${{ matrix.java }}
- uses: actions/setup-java@v1
+ uses: actions/setup-java@v3
with:
+ distribution: temurin
java-version: ${{ matrix.java }}
- name: Install Python 3.8
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v4
# We should install one Python that is higher then 3+ for SQL and Yarn because:
# - SQL component also has Python related tests, for example, IntegratedUDFTestUtils.
# - Yarn has a Python specific test too, for example, YarnClusterSuite.
@@ -244,11 +243,11 @@ jobs:
- name: Install Python packages (Python 3.8)
if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-'))
run: |
- python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy xmlrunner
+ python3.8 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'grpcio==1.48.1' 'protobuf==3.19.5'
python3.8 -m pip list
# Run the tests.
- name: Run tests
- env: ${{ fromJSON(needs.configure-jobs.outputs.envs) }}
+ env: ${{ fromJSON(inputs.envs) }}
run: |
# Hive "other tests" test needs larger metaspace size based on experiment.
if [[ "$MODULES_TO_TEST" == "hive" ]] && [[ "$EXCLUDED_TAGS" == "org.apache.spark.tags.SlowHiveTest" ]]; then export METASPACE_SIZE=2g; fi
@@ -256,35 +255,78 @@ jobs:
./dev/run-tests --parallelism 1 --modules "$MODULES_TO_TEST" --included-tags "$INCLUDED_TAGS" --excluded-tags "$EXCLUDED_TAGS"
- name: Upload test results to report
if: always()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
name: test-results-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }}
path: "**/target/test-reports/*.xml"
- name: Upload unit tests log files
if: failure()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
name: unit-tests-log-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }}
path: "**/target/unit-tests.log"
- pyspark:
- needs: [configure-jobs, precondition]
- # Run PySpark coverage scheduled jobs for Apache Spark only
- # Run scheduled jobs with JDK 17 in Apache Spark
- # Run regular jobs for commit in both Apache Spark and forked repository
+ infra-image:
+ name: "Base image build"
+ needs: precondition
+    # Currently, the Docker build from cache is only enabled for `master` branch jobs
if: >-
- (github.repository == 'apache/spark' && needs.configure-jobs.outputs.type == 'pyspark-coverage-scheduled')
- || (github.repository == 'apache/spark' && needs.configure-jobs.outputs.type == 'scheduled' && needs.configure-jobs.outputs.java == '17')
- || (needs.configure-jobs.outputs.type == 'regular' && fromJson(needs.precondition.outputs.required).pyspark == 'true')
- name: "Build modules (${{ format('{0}, {1} job', needs.configure-jobs.outputs.branch, needs.configure-jobs.outputs.type) }}): ${{ matrix.modules }}"
+ (fromJson(needs.precondition.outputs.required).pyspark == 'true' ||
+ fromJson(needs.precondition.outputs.required).lint == 'true' ||
+ fromJson(needs.precondition.outputs.required).sparkr == 'true') &&
+ inputs.branch == 'branch-3.4'
+ runs-on: ubuntu-latest
+ permissions:
+ packages: write
+ steps:
+ - name: Login to GitHub Container Registry
+ uses: docker/login-action@v2
+ with:
+ registry: ghcr.io
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+ - name: Checkout Spark repository
+ uses: actions/checkout@v3
+ # In order to fetch changed files
+ with:
+ fetch-depth: 0
+ repository: apache/spark
+ ref: ${{ inputs.branch }}
+ - name: Sync the current branch with the latest in Apache Spark
+ if: github.repository != 'apache/spark'
+ run: |
+ echo "APACHE_SPARK_REF=$(git rev-parse HEAD)" >> $GITHUB_ENV
+ git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@v2
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v2
+ - name: Build and push
+ id: docker_build
+ uses: docker/build-push-action@v3
+ with:
+ context: ./dev/infra/
+ push: true
+ tags: |
+ ${{ needs.precondition.outputs.image_url }}
+        # Use the infra image cache to speed up the build
+ cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ inputs.branch }}
+
+ pyspark:
+ needs: [precondition, infra-image]
+    # always run if pyspark == 'true', even if infra-image is skipped (such as in non-master jobs)
+ if: always() && fromJson(needs.precondition.outputs.required).pyspark == 'true'
+ name: "Build modules: ${{ matrix.modules }}"
runs-on: ubuntu-20.04
container:
- image: dongjoon/apache-spark-github-action-image:20220207
+ image: ${{ needs.precondition.outputs.image_url }}
strategy:
fail-fast: false
matrix:
java:
- - ${{ needs.configure-jobs.outputs.java }}
+ - ${{ inputs.java }}
modules:
- >-
pyspark-sql, pyspark-mllib, pyspark-resource
@@ -294,34 +336,38 @@ jobs:
pyspark-pandas
- >-
pyspark-pandas-slow
+ - >-
+ pyspark-connect, pyspark-errors
env:
MODULES_TO_TEST: ${{ matrix.modules }}
- HADOOP_PROFILE: ${{ needs.configure-jobs.outputs.hadoop }}
+ HADOOP_PROFILE: ${{ inputs.hadoop }}
HIVE_PROFILE: hive2.3
GITHUB_PREV_SHA: ${{ github.event.before }}
SPARK_LOCAL_IP: localhost
SKIP_UNIDOC: true
SKIP_MIMA: true
METASPACE_SIZE: 1g
- SPARK_ANSI_SQL_MODE: ${{ inputs.ansi_enabled }}
steps:
- name: Checkout Spark repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
# In order to fetch changed files
with:
fetch-depth: 0
repository: apache/spark
- ref: branch-3.3
+ ref: ${{ inputs.branch }}
+ - name: Add GITHUB_WORKSPACE to git trust safe.directory
+ run: |
+ git config --global --add safe.directory ${GITHUB_WORKSPACE}
- name: Sync the current branch with the latest in Apache Spark
if: github.repository != 'apache/spark'
run: |
echo "APACHE_SPARK_REF=$(git rev-parse HEAD)" >> $GITHUB_ENV
git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
- git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit"
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
# Cache local repositories. Note that GitHub Actions cache has a 2G limit.
- name: Cache Scala, SBT and Maven
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: |
build/apache-maven-*
@@ -332,15 +378,16 @@ jobs:
restore-keys: |
build-
- name: Cache Coursier local repository
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/coursier
key: pyspark-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
restore-keys: |
pyspark-coursier-
- name: Install Java ${{ matrix.java }}
- uses: actions/setup-java@v1
+ uses: actions/setup-java@v3
with:
+ distribution: temurin
java-version: ${{ matrix.java }}
- name: List Python packages (Python 3.9, PyPy3)
run: |
@@ -352,12 +399,12 @@ jobs:
bash miniconda.sh -b -p $HOME/miniconda
# Run the tests.
- name: Run tests
- env: ${{ fromJSON(needs.configure-jobs.outputs.envs) }}
+ env: ${{ fromJSON(inputs.envs) }}
run: |
export PATH=$PATH:$HOME/miniconda/bin
./dev/run-tests --parallelism 1 --modules "$MODULES_TO_TEST"
- name: Upload coverage to Codecov
- if: needs.configure-jobs.outputs.type == 'pyspark-coverage-scheduled'
+ if: fromJSON(inputs.envs).PYSPARK_CODECOV == 'true'
uses: codecov/codecov-action@v2
with:
files: ./python/coverage.xml
@@ -365,51 +412,52 @@ jobs:
name: PySpark
- name: Upload test results to report
if: always()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
- name: test-results-${{ matrix.modules }}--8-${{ needs.configure-jobs.outputs.hadoop }}-hive2.3
+ name: test-results-${{ matrix.modules }}--8-${{ inputs.hadoop }}-hive2.3
path: "**/target/test-reports/*.xml"
- name: Upload unit tests log files
if: failure()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
- name: unit-tests-log-${{ matrix.modules }}--8-${{ needs.configure-jobs.outputs.hadoop }}-hive2.3
+ name: unit-tests-log-${{ matrix.modules }}--8-${{ inputs.hadoop }}-hive2.3
path: "**/target/unit-tests.log"
sparkr:
- needs: [configure-jobs, precondition]
- if: >-
- (needs.configure-jobs.outputs.type == 'regular' && fromJson(needs.precondition.outputs.required).sparkr == 'true')
- || (github.repository == 'apache/spark' && needs.configure-jobs.outputs.type == 'scheduled' && needs.configure-jobs.outputs.java == '17')
+ needs: [precondition, infra-image]
+    # always run if sparkr == 'true', even if infra-image is skipped (such as in non-master jobs)
+ if: always() && fromJson(needs.precondition.outputs.required).sparkr == 'true'
name: "Build modules: sparkr"
runs-on: ubuntu-20.04
container:
- image: dongjoon/apache-spark-github-action-image:20220207
+ image: ${{ needs.precondition.outputs.image_url }}
env:
- HADOOP_PROFILE: ${{ needs.configure-jobs.outputs.hadoop }}
+ HADOOP_PROFILE: ${{ inputs.hadoop }}
HIVE_PROFILE: hive2.3
GITHUB_PREV_SHA: ${{ github.event.before }}
SPARK_LOCAL_IP: localhost
SKIP_MIMA: true
- SPARK_ANSI_SQL_MODE: ${{ inputs.ansi_enabled }}
steps:
- name: Checkout Spark repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
# In order to fetch changed files
with:
fetch-depth: 0
repository: apache/spark
- ref: branch-3.3
+ ref: ${{ inputs.branch }}
+ - name: Add GITHUB_WORKSPACE to git trust safe.directory
+ run: |
+ git config --global --add safe.directory ${GITHUB_WORKSPACE}
- name: Sync the current branch with the latest in Apache Spark
if: github.repository != 'apache/spark'
run: |
echo "APACHE_SPARK_REF=$(git rev-parse HEAD)" >> $GITHUB_ENV
git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
- git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit"
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
# Cache local repositories. Note that GitHub Actions cache has a 2G limit.
- name: Cache Scala, SBT and Maven
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: |
build/apache-maven-*
@@ -420,17 +468,19 @@ jobs:
restore-keys: |
build-
- name: Cache Coursier local repository
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/coursier
key: sparkr-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
restore-keys: |
sparkr-coursier-
- - name: Install Java ${{ needs.configure-jobs.outputs.java }}
- uses: actions/setup-java@v1
+ - name: Install Java ${{ inputs.java }}
+ uses: actions/setup-java@v3
with:
- java-version: ${{ needs.configure-jobs.outputs.java }}
+ distribution: temurin
+ java-version: ${{ inputs.java }}
- name: Run tests
+ env: ${{ fromJSON(inputs.envs) }}
run: |
# The followings are also used by `r-lib/actions/setup-r` to avoid
# R issues at docker environment
@@ -439,15 +489,16 @@ jobs:
./dev/run-tests --parallelism 1 --modules sparkr
- name: Upload test results to report
if: always()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
- name: test-results-sparkr--8-${{ needs.configure-jobs.outputs.hadoop }}-hive2.3
+ name: test-results-sparkr--8-${{ inputs.hadoop }}-hive2.3
path: "**/target/test-reports/*.xml"
# Static analysis, and documentation build
lint:
- needs: configure-jobs
- if: needs.configure-jobs.outputs.type == 'regular'
+ needs: [precondition, infra-image]
+    # always run if lint == 'true', even if infra-image is skipped (such as in non-master jobs)
+ if: always() && fromJson(needs.precondition.outputs.required).lint == 'true'
name: Linters, licenses, dependencies and documentation generation
runs-on: ubuntu-20.04
env:
@@ -455,24 +506,29 @@ jobs:
LANG: C.UTF-8
PYSPARK_DRIVER_PYTHON: python3.9
PYSPARK_PYTHON: python3.9
+ GITHUB_PREV_SHA: ${{ github.event.before }}
container:
- image: dongjoon/apache-spark-github-action-image:20220207
+ image: ${{ needs.precondition.outputs.image_url }}
steps:
- name: Checkout Spark repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
with:
fetch-depth: 0
repository: apache/spark
- ref: branch-3.3
+ ref: ${{ inputs.branch }}
+ - name: Add GITHUB_WORKSPACE to git trust safe.directory
+ run: |
+ git config --global --add safe.directory ${GITHUB_WORKSPACE}
- name: Sync the current branch with the latest in Apache Spark
if: github.repository != 'apache/spark'
run: |
+ echo "APACHE_SPARK_REF=$(git rev-parse HEAD)" >> $GITHUB_ENV
git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
- git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit"
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
# Cache local repositories. Note that GitHub Actions cache has a 2G limit.
- name: Cache Scala, SBT and Maven
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: |
build/apache-maven-*
@@ -483,27 +539,60 @@ jobs:
restore-keys: |
build-
- name: Cache Coursier local repository
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/coursier
key: docs-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
restore-keys: |
docs-coursier-
- name: Cache Maven local repository
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.m2/repository
key: docs-maven-${{ hashFiles('**/pom.xml') }}
restore-keys: |
docs-maven-
+ - name: Install Java 8
+ uses: actions/setup-java@v3
+ with:
+ distribution: temurin
+ java-version: 8
+ - name: License test
+ run: ./dev/check-license
+ - name: Dependencies test
+ run: ./dev/test-dependencies.sh
+ - name: Scala linter
+ run: ./dev/lint-scala
+ - name: Java linter
+ run: ./dev/lint-java
+ - name: Spark connect jvm client mima check
+ if: inputs.branch != 'branch-3.2' && inputs.branch != 'branch-3.3'
+ run: ./dev/connect-jvm-client-mima-check
- name: Install Python linter dependencies
run: |
# TODO(SPARK-32407): Sphinx 3.1+ does not correctly index nested classes.
# See also https://github.com/sphinx-doc/sphinx/issues/7551.
# Jinja2 3.0.0+ causes error when building with Sphinx.
# See also https://issues.apache.org/jira/browse/SPARK-35375.
- python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==21.12b0'
- python3.9 -m pip install 'pandas-stubs==1.2.0.53'
+ python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0'
+ python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.48.1' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0'
+ - name: Python linter
+ run: PYTHON_EXECUTABLE=python3.9 ./dev/lint-python
+ - name: Install dependencies for Python code generation check
+ run: |
+ # See more in "Installation" https://docs.buf.build/installation#tarball
+ curl -LO https://github.com/bufbuild/buf/releases/download/v1.15.1/buf-Linux-x86_64.tar.gz
+ mkdir -p $HOME/buf
+ tar -xvzf buf-Linux-x86_64.tar.gz -C $HOME/buf --strip-components 1
+ python3.9 -m pip install 'protobuf==3.19.5' 'mypy-protobuf==3.3.0'
+ - name: Python code generation check
+ run: if test -f ./dev/connect-check-protos.py; then PATH=$PATH:$HOME/buf/bin PYTHON_EXECUTABLE=python3.9 ./dev/connect-check-protos.py; fi
+ - name: Install JavaScript linter dependencies
+ run: |
+ apt update
+ apt-get install -y nodejs npm
+ - name: JS linter
+ run: ./dev/lint-js
- name: Install R linter dependencies and SparkR
run: |
apt update
@@ -513,10 +602,6 @@ jobs:
Rscript -e "install.packages(c('devtools'), repos='https://cloud.r-project.org/')"
Rscript -e "devtools::install_version('lintr', version='2.0.1', repos='https://cloud.r-project.org')"
./R/install-dev.sh
- - name: Instll JavaScript linter dependencies
- run: |
- apt update
- apt-get install -y nodejs npm
- name: Install dependencies for documentation generation
run: |
# pandoc is required to generate PySpark APIs as well in nbsphinx.
@@ -527,9 +612,9 @@ jobs:
# See also https://issues.apache.org/jira/browse/SPARK-35375.
# Pin the MarkupSafe to 2.0.1 to resolve the CI error.
# See also https://issues.apache.org/jira/browse/SPARK-38279.
- python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme ipython nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0'
+ python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0'
python3.9 -m pip install ipython_genutils # See SPARK-38517
- python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8'
+ python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8'
python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421
apt-get update -y
apt-get install -y ruby ruby-dev
@@ -539,32 +624,22 @@ jobs:
gem install bundler
cd docs
bundle install
- - name: Install Java 8
- uses: actions/setup-java@v1
- with:
- java-version: 8
- - name: Scala linter
- run: ./dev/lint-scala
- - name: Java linter
- run: ./dev/lint-java
- - name: Python linter
- run: PYTHON_EXECUTABLE=python3.9 ./dev/lint-python
- name: R linter
run: ./dev/lint-r
- - name: JS linter
- run: ./dev/lint-js
- - name: License test
- run: ./dev/check-license
- - name: Dependencies test
- run: ./dev/test-dependencies.sh
- name: Run documentation build
run: |
+ if [ -f "./dev/is-changed.py" ]; then
+ # Skip PySpark and SparkR docs while keeping Scala/Java/SQL docs
+ pyspark_modules=`cd dev && python3.9 -c "import sparktestsupport.modules as m; print(','.join(m.name for m in m.all_modules if m.name.startswith('pyspark')))"`
+ if [ `./dev/is-changed.py -m $pyspark_modules` = false ]; then export SKIP_PYTHONDOC=1; fi
+ if [ `./dev/is-changed.py -m sparkr` = false ]; then export SKIP_RDOC=1; fi
+ fi
cd docs
bundle exec jekyll build
java-11-17:
- needs: [configure-jobs, precondition]
- if: needs.configure-jobs.outputs.type == 'regular' && fromJson(needs.precondition.outputs.required).build == 'true'
+ needs: precondition
+ if: fromJson(needs.precondition.outputs.required).java-11-17 == 'true'
name: Java ${{ matrix.java }} build with Maven
strategy:
fail-fast: false
@@ -575,19 +650,19 @@ jobs:
runs-on: ubuntu-20.04
steps:
- name: Checkout Spark repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
with:
fetch-depth: 0
repository: apache/spark
- ref: branch-3.3
+ ref: ${{ inputs.branch }}
- name: Sync the current branch with the latest in Apache Spark
if: github.repository != 'apache/spark'
run: |
git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
- git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit"
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
- name: Cache Scala, SBT and Maven
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: |
build/apache-maven-*
@@ -598,15 +673,16 @@ jobs:
restore-keys: |
build-
- name: Cache Maven local repository
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.m2/repository
key: java${{ matrix.java }}-maven-${{ hashFiles('**/pom.xml') }}
restore-keys: |
java${{ matrix.java }}-maven-
- name: Install Java ${{ matrix.java }}
- uses: actions/setup-java@v1
+ uses: actions/setup-java@v3
with:
+ distribution: temurin
java-version: ${{ matrix.java }}
- name: Build with Maven
run: |
@@ -618,25 +694,25 @@ jobs:
rm -rf ~/.m2/repository/org/apache/spark
scala-213:
- needs: [configure-jobs, precondition]
- if: needs.configure-jobs.outputs.type == 'regular' && fromJson(needs.precondition.outputs.required).build == 'true'
+ needs: precondition
+ if: fromJson(needs.precondition.outputs.required).scala-213 == 'true'
name: Scala 2.13 build with SBT
runs-on: ubuntu-20.04
steps:
- name: Checkout Spark repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
with:
fetch-depth: 0
repository: apache/spark
- ref: branch-3.3
+ ref: ${{ inputs.branch }}
- name: Sync the current branch with the latest in Apache Spark
if: github.repository != 'apache/spark'
run: |
git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
- git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit"
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
- name: Cache Scala, SBT and Maven
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: |
build/apache-maven-*
@@ -647,44 +723,45 @@ jobs:
restore-keys: |
build-
- name: Cache Coursier local repository
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/coursier
key: scala-213-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
restore-keys: |
scala-213-coursier-
- name: Install Java 8
- uses: actions/setup-java@v1
+ uses: actions/setup-java@v3
with:
+ distribution: temurin
java-version: 8
- name: Build with SBT
run: |
./dev/change-scala-version.sh 2.13
- ./build/sbt -Pyarn -Pmesos -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Pkinesis-asl -Pdocker-integration-tests -Pkubernetes-integration-tests -Pspark-ganglia-lgpl -Pscala-2.13 compile test:compile
+ ./build/sbt -Pyarn -Pmesos -Pkubernetes -Pvolcano -Phive -Phive-thriftserver -Phadoop-cloud -Pkinesis-asl -Pdocker-integration-tests -Pkubernetes-integration-tests -Pspark-ganglia-lgpl -Pscala-2.13 compile Test/compile
+  # Any TPC-DS related updates to this job need to be applied to the tpcds-1g-gen job in benchmark.yml as well
tpcds-1g:
- needs: [configure-jobs, precondition]
- if: needs.configure-jobs.outputs.type == 'regular' && fromJson(needs.precondition.outputs.required).tpcds == 'true'
+ needs: precondition
+ if: fromJson(needs.precondition.outputs.required).tpcds-1g == 'true'
name: Run TPC-DS queries with SF=1
runs-on: ubuntu-20.04
env:
SPARK_LOCAL_IP: localhost
- SPARK_ANSI_SQL_MODE: ${{ inputs.ansi_enabled }}
steps:
- name: Checkout Spark repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
with:
fetch-depth: 0
repository: apache/spark
- ref: branch-3.3
+ ref: ${{ inputs.branch }}
- name: Sync the current branch with the latest in Apache Spark
if: github.repository != 'apache/spark'
run: |
git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
- git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit"
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
- name: Cache Scala, SBT and Maven
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: |
build/apache-maven-*
@@ -695,25 +772,26 @@ jobs:
restore-keys: |
build-
- name: Cache Coursier local repository
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/coursier
key: tpcds-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
restore-keys: |
tpcds-coursier-
- name: Install Java 8
- uses: actions/setup-java@v1
+ uses: actions/setup-java@v3
with:
+ distribution: temurin
java-version: 8
- name: Cache TPC-DS generated data
id: cache-tpcds-sf-1
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ./tpcds-sf-1
key: tpcds-${{ hashFiles('.github/workflows/build_and_test.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }}
- name: Checkout tpcds-kit repository
if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true'
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
with:
repository: databricks/tpcds-kit
ref: 2a5078a782192ddb6efbcead8de9973d6ab4f069
@@ -723,11 +801,12 @@ jobs:
run: cd tpcds-kit/tools && make OS=LINUX
- name: Generate TPC-DS (SF=1) table data
if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true'
- run: build/sbt "sql/test:runMain org.apache.spark.sql.GenTPCDSData --dsdgenDir `pwd`/tpcds-kit/tools --location `pwd`/tpcds-sf-1 --scaleFactor 1 --numPartitions 1 --overwrite"
+ run: build/sbt "sql/Test/runMain org.apache.spark.sql.GenTPCDSData --dsdgenDir `pwd`/tpcds-kit/tools --location `pwd`/tpcds-sf-1 --scaleFactor 1 --numPartitions 1 --overwrite"
- name: Run TPC-DS queries (Sort merge join)
run: |
SPARK_TPCDS_DATA=`pwd`/tpcds-sf-1 build/sbt "sql/testOnly org.apache.spark.sql.TPCDSQueryTestSuite"
env:
+ SPARK_ANSI_SQL_MODE: ${{ fromJSON(inputs.envs).SPARK_ANSI_SQL_MODE }}
SPARK_TPCDS_JOIN_CONF: |
spark.sql.autoBroadcastJoinThreshold=-1
spark.sql.join.preferSortMergeJoin=true
@@ -735,56 +814,58 @@ jobs:
run: |
SPARK_TPCDS_DATA=`pwd`/tpcds-sf-1 build/sbt "sql/testOnly org.apache.spark.sql.TPCDSQueryTestSuite"
env:
+ SPARK_ANSI_SQL_MODE: ${{ fromJSON(inputs.envs).SPARK_ANSI_SQL_MODE }}
SPARK_TPCDS_JOIN_CONF: |
spark.sql.autoBroadcastJoinThreshold=10485760
- name: Run TPC-DS queries (Shuffled hash join)
run: |
SPARK_TPCDS_DATA=`pwd`/tpcds-sf-1 build/sbt "sql/testOnly org.apache.spark.sql.TPCDSQueryTestSuite"
env:
+ SPARK_ANSI_SQL_MODE: ${{ fromJSON(inputs.envs).SPARK_ANSI_SQL_MODE }}
SPARK_TPCDS_JOIN_CONF: |
spark.sql.autoBroadcastJoinThreshold=-1
spark.sql.join.forceApplyShuffledHashJoin=true
- name: Upload test results to report
if: always()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
- name: test-results-tpcds--8-${{ needs.configure-jobs.outputs.hadoop }}-hive2.3
+ name: test-results-tpcds--8-${{ inputs.hadoop }}-hive2.3
path: "**/target/test-reports/*.xml"
- name: Upload unit tests log files
if: failure()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
- name: unit-tests-log-tpcds--8-${{ needs.configure-jobs.outputs.hadoop }}-hive2.3
+ name: unit-tests-log-tpcds--8-${{ inputs.hadoop }}-hive2.3
path: "**/target/unit-tests.log"
docker-integration-tests:
- needs: [configure-jobs, precondition]
- if: needs.configure-jobs.outputs.type == 'regular' && fromJson(needs.precondition.outputs.required).docker == 'true'
+ needs: precondition
+ if: fromJson(needs.precondition.outputs.required).docker-integration-tests == 'true'
name: Run Docker integration tests
runs-on: ubuntu-20.04
env:
- HADOOP_PROFILE: ${{ needs.configure-jobs.outputs.hadoop }}
+ HADOOP_PROFILE: ${{ inputs.hadoop }}
HIVE_PROFILE: hive2.3
GITHUB_PREV_SHA: ${{ github.event.before }}
SPARK_LOCAL_IP: localhost
- ORACLE_DOCKER_IMAGE_NAME: gvenzl/oracle-xe:18.4.0
+ ORACLE_DOCKER_IMAGE_NAME: gvenzl/oracle-xe:21.3.0
SKIP_MIMA: true
steps:
- name: Checkout Spark repository
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
with:
fetch-depth: 0
repository: apache/spark
- ref: branch-3.3
+ ref: ${{ inputs.branch }}
- name: Sync the current branch with the latest in Apache Spark
if: github.repository != 'apache/spark'
run: |
echo "APACHE_SPARK_REF=$(git rev-parse HEAD)" >> $GITHUB_ENV
git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
- git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit"
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
- name: Cache Scala, SBT and Maven
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: |
build/apache-maven-*
@@ -795,28 +876,100 @@ jobs:
restore-keys: |
build-
- name: Cache Coursier local repository
- uses: actions/cache@v2
+ uses: actions/cache@v3
with:
path: ~/.cache/coursier
key: docker-integration-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
restore-keys: |
docker-integration-coursier-
- name: Install Java 8
- uses: actions/setup-java@v1
+ uses: actions/setup-java@v3
with:
+ distribution: temurin
java-version: 8
- name: Run tests
run: |
./dev/run-tests --parallelism 1 --modules docker-integration-tests --included-tags org.apache.spark.tags.DockerTest
- name: Upload test results to report
if: always()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
- name: test-results-docker-integration--8-${{ needs.configure-jobs.outputs.hadoop }}-hive2.3
+ name: test-results-docker-integration--8-${{ inputs.hadoop }}-hive2.3
path: "**/target/test-reports/*.xml"
- name: Upload unit tests log files
if: failure()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
- name: unit-tests-log-docker-integration--8-${{ needs.configure-jobs.outputs.hadoop }}-hive2.3
+ name: unit-tests-log-docker-integration--8-${{ inputs.hadoop }}-hive2.3
path: "**/target/unit-tests.log"
+
+ k8s-integration-tests:
+ needs: precondition
+ if: fromJson(needs.precondition.outputs.required).k8s-integration-tests == 'true'
+ name: Run Spark on Kubernetes Integration test
+ runs-on: ubuntu-20.04
+ steps:
+ - name: Checkout Spark repository
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+ repository: apache/spark
+ ref: ${{ inputs.branch }}
+ - name: Sync the current branch with the latest in Apache Spark
+ if: github.repository != 'apache/spark'
+ run: |
+ echo "APACHE_SPARK_REF=$(git rev-parse HEAD)" >> $GITHUB_ENV
+ git fetch https://github.com/$GITHUB_REPOSITORY.git ${GITHUB_REF#refs/heads/}
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD
+ git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty
+ - name: Cache Scala, SBT and Maven
+ uses: actions/cache@v3
+ with:
+ path: |
+ build/apache-maven-*
+ build/scala-*
+ build/*.jar
+ ~/.sbt
+ key: build-${{ hashFiles('**/pom.xml', 'project/build.properties', 'build/mvn', 'build/sbt', 'build/sbt-launch-lib.bash', 'build/spark-build-info') }}
+ restore-keys: |
+ build-
+ - name: Cache Coursier local repository
+ uses: actions/cache@v3
+ with:
+ path: ~/.cache/coursier
+ key: k8s-integration-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
+ restore-keys: |
+ k8s-integration-coursier-
+ - name: Install Java ${{ inputs.java }}
+ uses: actions/setup-java@v3
+ with:
+ distribution: temurin
+ java-version: ${{ inputs.java }}
+ - name: start minikube
+ run: |
+ # See more in "Installation" https://minikube.sigs.k8s.io/docs/start/
+ curl -LO https://storage.googleapis.com/minikube/releases/latest/minikube-linux-amd64
+ sudo install minikube-linux-amd64 /usr/local/bin/minikube
+        # GitHub Actions runners are limited to 2 CPUs and 6947MB of memory; cap minikube at 2 CPUs / 6GB for more consistent resource statistics
+ minikube start --cpus 2 --memory 6144
+ - name: Print K8S pods and nodes info
+ run: |
+ kubectl get pods -A
+ kubectl describe node
+ - name: Run Spark on K8S integration test (With driver cpu 0.5, executor cpu 0.2 limited)
+ run: |
+ # Prepare PV test
+ PVC_TMP_DIR=$(mktemp -d)
+ export PVC_TESTS_HOST_PATH=$PVC_TMP_DIR
+ export PVC_TESTS_VM_PATH=$PVC_TMP_DIR
+ minikube mount ${PVC_TESTS_HOST_PATH}:${PVC_TESTS_VM_PATH} --gid=0 --uid=185 &
+ kubectl create clusterrolebinding serviceaccounts-cluster-admin --clusterrole=cluster-admin --group=system:serviceaccounts || true
+ kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.7.0/installer/volcano-development.yaml || true
+ eval $(minikube docker-env)
+ build/sbt -Psparkr -Pkubernetes -Pvolcano -Pkubernetes-integration-tests -Dspark.kubernetes.test.driverRequestCores=0.5 -Dspark.kubernetes.test.executorRequestCores=0.2 -Dspark.kubernetes.test.volcanoMaxConcurrencyJobNum=1 -Dtest.exclude.tags=local "kubernetes-integration-tests/test"
+ - name: Upload Spark on K8S integration tests log files
+ if: failure()
+ uses: actions/upload-artifact@v3
+ with:
+ name: spark-on-kubernetes-it-log
+ path: "**/target/integration-tests.log"
diff --git a/.github/workflows/build_and_test_ansi.yml b/.github/workflows/build_and_test_ansi.yml
deleted file mode 100644
index 3b8e44ff80ec3..0000000000000
--- a/.github/workflows/build_and_test_ansi.yml
+++ /dev/null
@@ -1,34 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-name: "Build and test (ANSI)"
-
-on:
- push:
- branches:
- - branch-3.3
-
-jobs:
- call-build-and-test:
- name: Call main build
- uses: ./.github/workflows/build_and_test.yml
- if: github.repository == 'apache/spark'
- with:
- ansi_enabled: true
-
diff --git a/.github/workflows/build_ansi.yml b/.github/workflows/build_ansi.yml
new file mode 100644
index 0000000000000..e67a9262fcd70
--- /dev/null
+++ b/.github/workflows/build_ansi.yml
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build / ANSI (master, Hadoop 3, JDK 8, Scala 2.12)"
+
+on:
+ schedule:
+ - cron: '0 1 * * *'
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 8
+ branch: master
+ hadoop: hadoop3
+ envs: >-
+ {
+ "SPARK_ANSI_SQL_MODE": "true",
+ }
+ jobs: >-
+ {
+ "build": "true",
+ "pyspark": "true",
+ "sparkr": "true",
+ "tpcds-1g": "true",
+ "docker-integration-tests": "true"
+ }
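For context on how the `envs` value above reaches the tests, here is a sketch (not a verbatim excerpt) of the consuming side in build_and_test.yml: the reusable workflow attaches `fromJSON(inputs.envs)` as the step environment, so a key such as SPARK_ANSI_SQL_MODE becomes an ordinary environment variable in the test step.

      - name: Run tests
        # inputs.envs is the JSON string passed by the caller, e.g. {"SPARK_ANSI_SQL_MODE": "true"}
        env: ${{ fromJSON(inputs.envs) }}
        run: |
          echo "ANSI SQL mode: $SPARK_ANSI_SQL_MODE"   # illustrative only
          ./dev/run-tests --parallelism 1 --modules "$MODULES_TO_TEST"
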
diff --git a/.github/workflows/build_branch32.yml b/.github/workflows/build_branch32.yml
new file mode 100644
index 0000000000000..723db45ca3755
--- /dev/null
+++ b/.github/workflows/build_branch32.yml
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build (branch-3.2, Scala 2.13, Hadoop 3, JDK 8)"
+
+on:
+ schedule:
+ - cron: '0 4 * * *'
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 8
+ branch: branch-3.2
+ hadoop: hadoop3.2
+ envs: >-
+ {
+ "SCALA_PROFILE": "scala2.13"
+ }
+ # TODO(SPARK-39712): Reenable "sparkr": "true"
+ # TODO(SPARK-39685): Reenable "lint": "true"
+ # TODO(SPARK-39681): Reenable "pyspark": "true"
+ # TODO(SPARK-39682): Reenable "docker-integration-tests": "true"
+ jobs: >-
+ {
+ "build": "true",
+ "tpcds-1g": "true"
+ }
diff --git a/.github/workflows/build_branch33.yml b/.github/workflows/build_branch33.yml
new file mode 100644
index 0000000000000..7ceafceb7180d
--- /dev/null
+++ b/.github/workflows/build_branch33.yml
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build (branch-3.3, Scala 2.13, Hadoop 3, JDK 8)"
+
+on:
+ schedule:
+ - cron: '0 7 * * *'
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 8
+ branch: branch-3.3
+ hadoop: hadoop3
+ envs: >-
+ {
+ "SCALA_PROFILE": "scala2.13"
+ }
+ jobs: >-
+ {
+ "build": "true",
+ "pyspark": "true",
+ "sparkr": "true",
+ "tpcds-1g": "true",
+ "docker-integration-tests": "true",
+ "lint" : "true"
+ }
diff --git a/.github/workflows/build_coverage.yml b/.github/workflows/build_coverage.yml
new file mode 100644
index 0000000000000..aa210f0031866
--- /dev/null
+++ b/.github/workflows/build_coverage.yml
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build / Coverage (master, Scala 2.12, Hadoop 3, JDK 8)"
+
+on:
+ schedule:
+ - cron: '0 10 * * *'
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 8
+ branch: master
+ hadoop: hadoop3
+ envs: >-
+ {
+ "PYSPARK_CODECOV": "true"
+ }
+ jobs: >-
+ {
+ "pyspark": "true"
+ }
diff --git a/.github/workflows/build_hadoop2.yml b/.github/workflows/build_hadoop2.yml
new file mode 100644
index 0000000000000..9716d568be8e0
--- /dev/null
+++ b/.github/workflows/build_hadoop2.yml
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build (master, Scala 2.12, Hadoop 2, JDK 8)"
+
+on:
+ schedule:
+ - cron: '0 13 * * *'
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 8
+ branch: master
+ hadoop: hadoop2
+ # TODO(SPARK-39684): Reenable "docker-integration-tests": "true"
+ jobs: >-
+ {
+ "build": "true",
+ "pyspark": "true",
+ "sparkr": "true",
+ "tpcds-1g": "true"
+ }
diff --git a/.github/workflows/build_infra_images_cache.yml b/.github/workflows/build_infra_images_cache.yml
new file mode 100644
index 0000000000000..b8aae945599de
--- /dev/null
+++ b/.github/workflows/build_infra_images_cache.yml
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: Build / Cache base image
+
+on:
+ # Run jobs when a commit is merged
+ push:
+ branches:
+ - 'master'
+ - 'branch-*'
+ paths:
+ - 'dev/infra/Dockerfile'
+ - '.github/workflows/build_infra_images_cache.yml'
+ # Create infra image when cutting down branches/tags
+ create:
+jobs:
+ main:
+ if: github.repository == 'apache/spark'
+ runs-on: ubuntu-latest
+ permissions:
+ packages: write
+ steps:
+ - name: Checkout Spark repository
+ uses: actions/checkout@v3
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@v2
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v2
+ - name: Login to DockerHub
+ uses: docker/login-action@v2
+ with:
+ registry: ghcr.io
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+ - name: Build and push
+ id: docker_build
+ uses: docker/build-push-action@v3
+ with:
+ context: ./dev/infra/
+ push: true
+ tags: ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ github.ref_name }}-static
+ cache-from: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ github.ref_name }}
+ cache-to: type=registry,ref=ghcr.io/apache/spark/apache-spark-github-action-image-cache:${{ github.ref_name }},mode=max
+ - name: Image digest
+ run: echo ${{ steps.docker_build.outputs.digest }}
diff --git a/.github/workflows/build_java11.yml b/.github/workflows/build_java11.yml
new file mode 100644
index 0000000000000..bf7b2edb45ff3
--- /dev/null
+++ b/.github/workflows/build_java11.yml
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build (master, Scala 2.12, Hadoop 3, JDK 11)"
+
+on:
+ schedule:
+ - cron: '0 16 * * *'
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 11
+ branch: master
+ hadoop: hadoop3
+ envs: >-
+ {
+ "SKIP_MIMA": "true",
+ "SKIP_UNIDOC": "true"
+ }
+ jobs: >-
+ {
+ "build": "true",
+ "pyspark": "true",
+ "sparkr": "true",
+ "tpcds-1g": "true",
+ "docker-integration-tests": "true"
+ }
diff --git a/.github/workflows/build_java17.yml b/.github/workflows/build_java17.yml
new file mode 100644
index 0000000000000..9465e5ea0e317
--- /dev/null
+++ b/.github/workflows/build_java17.yml
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build (master, Scala 2.12, Hadoop 3, JDK 17)"
+
+on:
+ schedule:
+ - cron: '0 22 * * *'
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 17
+ branch: master
+ hadoop: hadoop3
+ envs: >-
+ {
+ "SKIP_MIMA": "true",
+ "SKIP_UNIDOC": "true"
+ }
+ jobs: >-
+ {
+ "build": "true",
+ "pyspark": "true",
+ "sparkr": "true",
+ "tpcds-1g": "true",
+ "docker-integration-tests": "true"
+ }
diff --git a/.github/workflows/build_main.yml b/.github/workflows/build_main.yml
new file mode 100644
index 0000000000000..1ac6c87b7d041
--- /dev/null
+++ b/.github/workflows/build_main.yml
@@ -0,0 +1,32 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build"
+
+on:
+ push:
+ branches:
+ - '**'
+
+jobs:
+ call-build-and-test:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
diff --git a/.github/workflows/build_rockdb_as_ui_backend.yml b/.github/workflows/build_rockdb_as_ui_backend.yml
new file mode 100644
index 0000000000000..04e0e7c2e1073
--- /dev/null
+++ b/.github/workflows/build_rockdb_as_ui_backend.yml
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build / RocksDB as UI Backend (master, Hadoop 3, JDK 8, Scala 2.12)"
+
+on:
+ schedule:
+ - cron: '0 6 * * *'
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 8
+ branch: master
+ hadoop: hadoop3
+ envs: >-
+ {
+ "LIVE_UI_LOCAL_STORE_DIR": "/tmp/kvStore",
+ }
+ jobs: >-
+ {
+ "build": "true",
+ "pyspark": "true",
+ "sparkr": "true",
+ "tpcds-1g": "true",
+ "docker-integration-tests": "true"
+ }
diff --git a/.github/workflows/build_scala213.yml b/.github/workflows/build_scala213.yml
new file mode 100644
index 0000000000000..cae0981ee1e8a
--- /dev/null
+++ b/.github/workflows/build_scala213.yml
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: "Build (master, Scala 2.13, Hadoop 3, JDK 8)"
+
+on:
+ schedule:
+ - cron: '0 19 * * *'
+
+jobs:
+ run-build:
+ permissions:
+ packages: write
+ name: Run
+ uses: ./.github/workflows/build_and_test.yml
+ if: github.repository == 'apache/spark'
+ with:
+ java: 8
+ branch: master
+ hadoop: hadoop3
+ envs: >-
+ {
+ "SCALA_PROFILE": "scala2.13"
+ }
+ jobs: >-
+ {
+ "build": "true",
+ "pyspark": "true",
+ "sparkr": "true",
+ "tpcds-1g": "true",
+ "docker-integration-tests": "true",
+ "lint" : "true"
+ }
diff --git a/.github/workflows/cancel_duplicate_workflow_runs.yml b/.github/workflows/cancel_duplicate_workflow_runs.yml
index 525c7e7972c2a..d41ca31190d94 100644
--- a/.github/workflows/cancel_duplicate_workflow_runs.yml
+++ b/.github/workflows/cancel_duplicate_workflow_runs.yml
@@ -21,7 +21,7 @@ name: Cancelling Duplicates
on:
workflow_run:
workflows:
- - 'Build and test'
+ - 'Build'
types: ['requested']
jobs:
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
index 88d17bf34d504..c6b6e65bc9fec 100644
--- a/.github/workflows/labeler.yml
+++ b/.github/workflows/labeler.yml
@@ -30,6 +30,9 @@ jobs:
label:
name: Label pull requests
runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ pull-requests: write
steps:
# In order to get back the negated matches like in the old config,
# we need the actinons/labeler concept of `all` and `any` which matches
@@ -44,7 +47,7 @@ jobs:
#
# However, these are not in a published release and the current `main` branch
# has some issues upon testing.
- - uses: actions/labeler@5f867a63be70efff62b767459b009290364495eb # pin@2.2.0
+ - uses: actions/labeler@v4
with:
repo-token: "${{ secrets.GITHUB_TOKEN }}"
sync-labels: true
diff --git a/.github/workflows/notify_test_workflow.yml b/.github/workflows/notify_test_workflow.yml
index eb0da84a797c3..6fb776d708346 100644
--- a/.github/workflows/notify_test_workflow.yml
+++ b/.github/workflows/notify_test_workflow.yml
@@ -31,9 +31,12 @@ jobs:
notify:
name: Notify test workflow
runs-on: ubuntu-20.04
+ permissions:
+ actions: read
+ checks: write
steps:
- name: "Notify test workflow"
- uses: actions/github-script@f05a81df23035049204b043b50c3322045ce7eb3 # pin@v3
+ uses: actions/github-script@v6
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@@ -46,7 +49,7 @@ jobs:
const params = {
owner: context.payload.pull_request.head.repo.owner.login,
repo: context.payload.pull_request.head.repo.name,
- id: 'build_and_test.yml',
+ id: 'build_main.yml',
branch: context.payload.pull_request.head.ref,
}
const check_run_params = {
@@ -69,7 +72,7 @@ jobs:
// Assume that runs were not found.
}
- const name = 'Build and test'
+ const name = 'Build'
const head_sha = context.payload.pull_request.head.sha
let status = 'queued'
@@ -77,7 +80,7 @@ jobs:
status = 'completed'
const conclusion = 'action_required'
- github.checks.create({
+ github.rest.checks.create({
owner: context.repo.owner,
repo: context.repo.repo,
name: name,
@@ -113,7 +116,7 @@ jobs:
// Here we get check run ID to provide Check run view instead of Actions view, see also SPARK-37879.
const check_runs = await github.request(check_run_endpoint, check_run_params)
- const check_run_head = check_runs.data.check_runs.filter(r => r.name === "Configure jobs")[0]
+ const check_run_head = check_runs.data.check_runs.filter(r => r.name === "Run / Check changes")[0]
if (check_run_head.head_sha != context.payload.pull_request.head.sha) {
throw new Error('There was a new unsynced commit pushed. Please retrigger the workflow.');
@@ -129,7 +132,7 @@ jobs:
+ '/actions/runs/'
+ run_id
- github.checks.create({
+ github.rest.checks.create({
owner: context.repo.owner,
repo: context.repo.repo,
name: name,
diff --git a/.github/workflows/publish_snapshot.yml b/.github/workflows/publish_snapshot.yml
index bd75e26108658..f0a8ad5ef6a72 100644
--- a/.github/workflows/publish_snapshot.yml
+++ b/.github/workflows/publish_snapshot.yml
@@ -32,23 +32,24 @@ jobs:
matrix:
branch:
- master
+ - branch-3.3
- branch-3.2
- - branch-3.1
steps:
- name: Checkout Spark repository
- uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # pin@master
+ uses: actions/checkout@v3
with:
ref: ${{ matrix.branch }}
- name: Cache Maven local repository
- uses: actions/cache@c64c572235d810460d0d6876e9c705ad5002b353 # pin@v2
+ uses: actions/cache@v3
with:
path: ~/.m2/repository
key: snapshot-maven-${{ hashFiles('**/pom.xml') }}
restore-keys: |
snapshot-maven-
- name: Install Java 8
- uses: actions/setup-java@d202f5dbf7256730fb690ec59f6381650114feb2 # pin@v1
+ uses: actions/setup-java@v3
with:
+ distribution: temurin
java-version: 8
- name: Publish snapshot
env:
diff --git a/.github/workflows/test_report.yml b/.github/workflows/test_report.yml
index a3f09c06ed989..c6225e6a1abe5 100644
--- a/.github/workflows/test_report.yml
+++ b/.github/workflows/test_report.yml
@@ -20,12 +20,13 @@
name: Report test results
on:
workflow_run:
- workflows: ["Build and test", "Build and test (ANSI)"]
+ workflows: ["Build"]
types:
- completed
jobs:
test_report:
+ if: github.event.workflow_run.conclusion != 'skipped'
runs-on: ubuntu-latest
steps:
- name: Download test results to report
diff --git a/.github/workflows/update_build_status.yml b/.github/workflows/update_build_status.yml
index 671487adbfe05..05cf4914a25ca 100644
--- a/.github/workflows/update_build_status.yml
+++ b/.github/workflows/update_build_status.yml
@@ -27,9 +27,12 @@ jobs:
update:
name: Update build status
runs-on: ubuntu-20.04
+ permissions:
+ actions: read
+ checks: write
steps:
- name: "Update build status"
- uses: actions/github-script@f05a81df23035049204b043b50c3322045ce7eb3 # pin@v3
+ uses: actions/github-script@v6
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@@ -58,7 +61,7 @@ jobs:
 // Iterate over the GitHub Checks in the PR
for await (const cr of checkRuns.data.check_runs) {
- if (cr.name == 'Build and test' && cr.conclusion != "action_required") {
+ if (cr.name == 'Build' && cr.conclusion != "action_required") {
// text contains parameters to make request in JSON.
const params = JSON.parse(cr.output.text)
diff --git a/.gitignore b/.gitignore
index 0e2f59f43f83d..11141961bf805 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,10 +18,7 @@
.ensime_cache/
.ensime_lucene
.generated-mima*
-# All the files under .idea/ are ignore. To add new files under ./idea that are not in the VCS yet, please use `git add -f`
.idea/
-# SPARK-35223: Add IssueNavigationLink to make IDEA support hyperlink on JIRA Ticket and GitHub PR on Git plugin.
-!.idea/vcs.xml
.idea_modules/
.metals
.project
@@ -77,6 +74,7 @@ python/coverage.xml
python/deps
python/docs/_site/
python/docs/source/reference/**/api/
+python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
python/test_coverage/coverage_data
python/test_coverage/htmlcov
python/pyspark/python
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 28fd3fcdf10ea..0000000000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,36 +0,0 @@
-
-
-
-
-
-
-
-
-
-
diff --git a/LICENSE b/LICENSE
index df6bed16f4471..012fdbca4c90d 100644
--- a/LICENSE
+++ b/LICENSE
@@ -216,7 +216,7 @@ core/src/main/resources/org/apache/spark/ui/static/bootstrap*
core/src/main/resources/org/apache/spark/ui/static/jsonFormatter*
core/src/main/resources/org/apache/spark/ui/static/vis*
docs/js/vendor/bootstrap.js
-external/spark-ganglia-lgpl/src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java
+connector/spark-ganglia-lgpl/src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java
Python Software Foundation License
diff --git a/LICENSE-binary b/LICENSE-binary
index 40e2e389b2264..9472d28e509ac 100644
--- a/LICENSE-binary
+++ b/LICENSE-binary
@@ -382,6 +382,10 @@ org.eclipse.jetty:jetty-servlets
org.eclipse.jetty:jetty-util
org.eclipse.jetty:jetty-webapp
org.eclipse.jetty:jetty-xml
+org.scala-lang:scala-compiler
+org.scala-lang:scala-library
+org.scala-lang:scala-reflect
+org.scala-lang.modules:scala-parser-combinators_2.12
org.scala-lang.modules:scala-xml_2.12
com.github.joshelser:dropwizard-metrics-hadoop-metrics2-reporter
com.zaxxer.HikariCP
@@ -404,6 +408,7 @@ org.datanucleus:javax.jdo
com.tdunning:json
org.apache.velocity:velocity
org.apache.yetus:audience-annotations
+com.google.cloud.bigdataoss:gcs-connector
core/src/main/java/org/apache/spark/util/collection/TimSort.java
core/src/main/resources/org/apache/spark/ui/static/bootstrap*
@@ -426,7 +431,6 @@ javolution:javolution
com.esotericsoftware:kryo-shaded
com.esotericsoftware:minlog
com.esotericsoftware:reflectasm
-com.google.protobuf:protobuf-java
org.codehaus.janino:commons-compiler
org.codehaus.janino:janino
jline:jline
@@ -438,6 +442,7 @@ pl.edu.icm:JLargeArrays
BSD 3-Clause
------------
+com.google.protobuf:protobuf-java
dk.brics.automaton:automaton
org.antlr:antlr-runtime
org.antlr:ST4
@@ -445,10 +450,6 @@ org.antlr:stringtemplate
org.antlr:antlr4-runtime
antlr:antlr
com.thoughtworks.paranamer:paranamer
-org.scala-lang:scala-compiler
-org.scala-lang:scala-library
-org.scala-lang:scala-reflect
-org.scala-lang.modules:scala-parser-combinators_2.12
org.fusesource.leveldbjni:leveldbjni-all
net.sourceforge.f2j:arpack_combined_all
xmlenc:xmlenc
diff --git a/R/check-cran.sh b/R/check-cran.sh
index 22c8f423cfd12..4123361f5e285 100755
--- a/R/check-cran.sh
+++ b/R/check-cran.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/create-docs.sh b/R/create-docs.sh
index 4867fd99e647c..3deaefd0659dc 100755
--- a/R/create-docs.sh
+++ b/R/create-docs.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/create-rd.sh b/R/create-rd.sh
index 72a932c175c95..1f0527458f2f0 100755
--- a/R/create-rd.sh
+++ b/R/create-rd.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/find-r.sh b/R/find-r.sh
index 690acc083af91..f1a5026911a7f 100755
--- a/R/find-r.sh
+++ b/R/find-r.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/install-dev.sh b/R/install-dev.sh
index 9fbc999f2e805..7df21c6c5ec9a 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/install-source-package.sh b/R/install-source-package.sh
index 8de3569d1d482..0a2a5fe00f31f 100755
--- a/R/install-source-package.sh
+++ b/R/install-source-package.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 0e449e841cf6d..fa7028630a899 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -1,6 +1,6 @@
Package: SparkR
Type: Package
-Version: 3.3.1
+Version: 3.4.1
Title: R Front End for 'Apache Spark'
Description: Provides an R Front end for 'Apache Spark' .
Authors@R:
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 6e0557cff88ce..bb05e99a9d8a6 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -143,6 +143,7 @@ exportMethods("arrange",
"join",
"limit",
"localCheckpoint",
+ "melt",
"merge",
"mutate",
"na.omit",
@@ -182,6 +183,7 @@ exportMethods("arrange",
"unionByName",
"unique",
"unpersist",
+ "unpivot",
"where",
"with",
"withColumn",
@@ -474,9 +476,16 @@ export("as.DataFrame",
"createDataFrame",
"createExternalTable",
"createTable",
+ "currentCatalog",
"currentDatabase",
+ "databaseExists",
"dropTempTable",
"dropTempView",
+ "functionExists",
+ "getDatabase",
+ "getFunc",
+ "getTable",
+ "listCatalogs",
"listColumns",
"listDatabases",
"listFunctions",
@@ -493,6 +502,7 @@ export("as.DataFrame",
"refreshByPath",
"refreshTable",
"setCheckpointDir",
+ "setCurrentCatalog",
"setCurrentDatabase",
"spark.lapply",
"spark.addFile",
@@ -500,6 +510,7 @@ export("as.DataFrame",
"spark.getSparkFiles",
"sql",
"str",
+ "tableExists",
"tableToDF",
"tableNames",
"tables",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index e143cbd8256f9..3f9bc9cb6d053 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -3366,7 +3366,7 @@ setMethod("na.omit",
setMethod("fillna",
signature(x = "SparkDataFrame"),
function(x, value, cols = NULL) {
- if (!(class(value) %in% c("integer", "numeric", "character", "list"))) {
+ if (!(inherits(value, c("integer", "numeric", "character", "list")))) {
stop("value should be an integer, numeric, character or named list.")
}
@@ -3378,7 +3378,7 @@ setMethod("fillna",
}
# Check each item in the named list is of valid type
lapply(value, function(v) {
- if (!(class(v) %in% c("integer", "numeric", "character"))) {
+ if (!(inherits(v, c("integer", "numeric", "character")))) {
stop("Each item in value should be an integer, numeric or character.")
}
})
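
This patch consistently replaces class(x) == "..." comparisons with inherits(). A minimal plain-R sketch (no Spark needed, and not part of the patch itself) of why the equality form is fragile when an object carries more than one class:

    m <- matrix(1:4, nrow = 2)
    class(m)                     # "matrix" "array" in R >= 4.0
    class(m) == "matrix"         # TRUE FALSE: a length-2 logical, easy to misuse in if()
    inherits(m, "matrix")        # TRUE: a single logical, and it honors S3 inheritance
    inherits(m, c("integer", "numeric", "character", "list"))  # FALSE, as fillna() expects
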
@@ -3577,41 +3577,56 @@ setMethod("str",
#' This is a no-op if schema doesn't contain column name(s).
#'
#' @param x a SparkDataFrame.
-#' @param col a character vector of column names or a Column.
-#' @param ... further arguments to be passed to or from other methods.
-#' @return A SparkDataFrame.
+#' @param col a list of columns or single Column or name.
+#' @param ... additional column(s) if only one column is specified in \code{col}.
+#' If more than one column is assigned in \code{col}, \code{...}
+#' should be left empty.
+#' @return A new SparkDataFrame with selected columns.
#'
#' @family SparkDataFrame functions
#' @rdname drop
#' @name drop
-#' @aliases drop,SparkDataFrame-method
+#' @aliases drop,SparkDataFrame,characterOrColumn-method
#' @examples
-#'\dontrun{
+#' \dontrun{
#' sparkR.session()
#' path <- "path/to/file.json"
#' df <- read.json(path)
#' drop(df, "col1")
#' drop(df, c("col1", "col2"))
#' drop(df, df$col1)
+#' drop(df, "col1", "col2")
+#' drop(df, df$name, df$age)
#' }
-#' @note drop since 2.0.0
+#' @note drop(SparkDataFrame, characterOrColumn, ...) since 3.4.0
setMethod("drop",
- signature(x = "SparkDataFrame"),
- function(x, col) {
- stopifnot(class(col) == "character" || class(col) == "Column")
-
- if (class(col) == "Column") {
- sdf <- callJMethod(x@sdf, "drop", col@jc)
+ signature(x = "SparkDataFrame", col = "characterOrColumn"),
+ function(x, col, ...) {
+ if (class(col) == "character" && length(col) > 1) {
+ if (length(list(...)) > 0) {
+ stop("To drop multiple columns, use a character vector or ... for character/Column")
+ }
+ cols <- as.list(col)
} else {
- sdf <- callJMethod(x@sdf, "drop", as.list(col))
+ cols <- list(col, ...)
}
+
+ cols <- lapply(cols, function(c) {
+ if (class(c) == "Column") {
+ c@jc
+ } else {
+ col(c)@jc
+ }
+ })
+
+ sdf <- callJMethod(x@sdf, "drop", cols[[1]], cols[-1])
dataFrame(sdf)
})
# Expose base::drop
#' @name drop
#' @rdname drop
-#' @aliases drop,ANY-method
+#' @aliases drop,ANY,ANY-method
setMethod("drop",
signature(x = "ANY"),
function(x) {
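
A usage sketch of the widened drop() signature above, assuming an active SparkR session; the data frame and column names are illustrative only:

    library(SparkR)
    sparkR.session()

    df <- createDataFrame(data.frame(name = c("a", "b"), age = c(1L, 2L), height = c(10, 20)))

    # Single column, as before.
    columns(drop(df, "height"))               # "name" "age"

    # Several columns at once: a character vector, several names, or several Columns.
    columns(drop(df, c("age", "height")))     # "name"
    columns(drop(df, "age", "height"))        # "name"
    columns(drop(df, df$age, df$height))      # "name"
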
@@ -4238,3 +4253,76 @@ setMethod("withWatermark",
sdf <- callJMethod(x@sdf, "withWatermark", eventTime, delayThreshold)
dataFrame(sdf)
})
+
+#' Unpivot a DataFrame from wide format to long format.
+#'
+#' This is the reverse to \code{groupBy(...).pivot(...).agg(...)},
+#' except for the aggregation, which cannot be reversed.
+#'
+#' @param x a SparkDataFrame.
+#' @param ids a character vector or a list of columns
+#' @param values a character vector, a list of columns or \code{NULL}.
+#' If not NULL, must not be empty. If \code{NULL}, uses all columns that
+#' are not set as \code{ids}.
+#' @param variableColumnName character Name of the variable column.
+#' @param valueColumnName character Name of the value column.
+#' @return a SparkDataFrame.
+#' @aliases unpivot,SparkDataFrame,ANY,ANY,character,character-method
+#' @family SparkDataFrame functions
+#' @rdname unpivot
+#' @name unpivot
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(data.frame(
+#' id = 1:3, x = c(1, 3, 5), y = c(2, 4, 6), z = c(-1, 0, 1)
+#' ))
+#'
+#' head(unpivot(df, "id", c("x", "y"), "var", "val"))
+#'
+#' head(unpivot(df, "id", NULL, "var", "val"))
+#' }
+#' @note unpivot since 3.4.0
+setMethod("unpivot",
+ signature(
+ x = "SparkDataFrame", ids = "ANY", values = "ANY",
+ variableColumnName = "character", valueColumnName = "character"
+ ),
+ function(x, ids, values, variableColumnName, valueColumnName) {
+ as_jcols <- function(xs) lapply(
+ xs,
+ function(x) {
+ if (is.character(x)) {
+ column(x)@jc
+ } else {
+ x@jc
+ }
+ }
+ )
+
+ sdf <- if (is.null(values)) {
+ callJMethod(
+ x@sdf, "unpivotWithSeq", as_jcols(ids), variableColumnName, valueColumnName
+ )
+ } else {
+ callJMethod(
+ x@sdf, "unpivotWithSeq",
+ as_jcols(ids), as_jcols(values),
+ variableColumnName, valueColumnName
+ )
+ }
+ dataFrame(sdf)
+ })
+
+#' @rdname unpivot
+#' @name melt
+#' @aliases melt,SparkDataFrame,ANY,ANY,character,character-method
+#' @note melt since 3.4.0
+setMethod("melt",
+ signature(
+ x = "SparkDataFrame", ids = "ANY", values = "ANY",
+ variableColumnName = "character", valueColumnName = "character"
+ ),
+ function(x, ids, values, variableColumnName, valueColumnName) {
+ unpivot(x, ids, values, variableColumnName, valueColumnName)
+ }
+)
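
An end-to-end sketch of the new unpivot()/melt() pair, assuming an active SparkR session; melt() is the alias defined directly above:

    library(SparkR)
    sparkR.session()

    df <- createDataFrame(data.frame(id = 1:3, x = c(1, 3, 5), y = c(2, 4, 6)))

    # Wide -> long: one row per (id, variable) pair.
    long <- unpivot(df, ids = "id", values = c("x", "y"),
                    variableColumnName = "var", valueColumnName = "val")
    head(long)

    # melt() is an alias with the same signature; values = NULL takes all non-id columns.
    head(melt(df, "id", NULL, "var", "val"))
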
diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R
index be47d0117ed7f..5c1de0beac3ca 100644
--- a/R/pkg/R/WindowSpec.R
+++ b/R/pkg/R/WindowSpec.R
@@ -135,7 +135,7 @@ setMethod("orderBy",
#' An offset indicates the number of rows above or below the current row, the frame for the
#' current row starts or ends. For instance, given a row based sliding frame with a lower bound
#' offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from
-#' index 4 to index 6.
+#' index 4 to index 7.
#'
#' @param x a WindowSpec
#' @param start boundary start, inclusive.
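
The corrected arithmetic above (bounds -1 and +2 around row index 5 cover rows 4 through 7) can be checked with rowsBetween(); a minimal sketch, assuming an active SparkR session and made-up data:

    library(SparkR)
    sparkR.session()

    df <- createDataFrame(data.frame(k = rep("a", 8), v = 1:8))

    # One row before the current row through two rows after it.
    ws <- rowsBetween(orderBy(windowPartitionBy("k"), "v"), -1, 2)

    # For the row with v == 5 the frame covers v in {4, 5, 6, 7}, so the sum is 22.
    head(select(df, df$v, over(sum(df$v), ws)), 8)
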
diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R
index 275737f804bde..942af4de3c0bb 100644
--- a/R/pkg/R/catalog.R
+++ b/R/pkg/R/catalog.R
@@ -17,6 +17,66 @@
# catalog.R: SparkSession catalog functions
+#' Returns the current default catalog
+#'
+#' Returns the current default catalog.
+#'
+#' @return name of the current default catalog.
+#' @rdname currentCatalog
+#' @name currentCatalog
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' currentCatalog()
+#' }
+#' @note since 3.4.0
+currentCatalog <- function() {
+ sparkSession <- getSparkSession()
+ catalog <- callJMethod(sparkSession, "catalog")
+ callJMethod(catalog, "currentCatalog")
+}
+
+#' Sets the current default catalog
+#'
+#' Sets the current default catalog.
+#'
+#' @param catalogName name of the catalog
+#' @rdname setCurrentCatalog
+#' @name setCurrentCatalog
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' setCurrentCatalog("spark_catalog")
+#' }
+#' @note since 3.4.0
+setCurrentCatalog <- function(catalogName) {
+ sparkSession <- getSparkSession()
+ if (class(catalogName) != "character") {
+ stop("catalogName must be a string.")
+ }
+ catalog <- callJMethod(sparkSession, "catalog")
+ invisible(handledCallJMethod(catalog, "setCurrentCatalog", catalogName))
+}
+
+#' Returns a list of catalogs available
+#'
+#' Returns a list of catalogs available.
+#'
+#' @return a SparkDataFrame of the list of catalogs.
+#' @rdname listCatalogs
+#' @name listCatalogs
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' listCatalogs()
+#' }
+#' @note since 3.4.0
+listCatalogs <- function() {
+ sparkSession <- getSparkSession()
+ catalog <- callJMethod(sparkSession, "catalog")
+ dataFrame(callJMethod(callJMethod(catalog, "listCatalogs"), "toDF"))
+}
+
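
A short usage sketch of the three catalog helpers added above, assuming an active SparkR session; spark_catalog is the built-in default catalog:

    library(SparkR)
    sparkR.session()

    currentCatalog()                      # "spark_catalog" by default
    setCurrentCatalog("spark_catalog")    # switch catalogs; returns invisibly on success
    head(listCatalogs())                  # SparkDataFrame describing the available catalogs
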
#' (Deprecated) Create an external table
#'
#' Creates an external table based on the dataset in a data source,
@@ -58,6 +118,7 @@ createExternalTable <- function(tableName, path = NULL, source = NULL, schema =
#'
#' @param tableName the qualified or unqualified name that designates a table. If no database
#' identifier is provided, it refers to a table in the current database.
+#' The table name can be fully qualified with catalog name since 3.4.0.
#' @param path (optional) the path of files to load.
#' @param source (optional) the name of the data source.
#' @param schema (optional) the schema of the data required for some data sources.
@@ -69,7 +130,7 @@ createExternalTable <- function(tableName, path = NULL, source = NULL, schema =
#' sparkR.session()
#' df <- createTable("myjson", path="path/to/json", source="json", schema)
#'
-#' createTable("people", source = "json", schema = schema)
+#' createTable("spark_catalog.default.people", source = "json", schema = schema)
#' insertInto(df, "people")
#' }
#' @name createTable
@@ -100,6 +161,7 @@ createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, ..
#'
#' @param tableName the qualified or unqualified name that designates a table. If no database
#' identifier is provided, it refers to a table in the current database.
+#' The table name can be fully qualified with catalog name since 3.4.0.
#' @return SparkDataFrame
#' @rdname cacheTable
#' @examples
@@ -124,6 +186,7 @@ cacheTable <- function(tableName) {
#'
#' @param tableName the qualified or unqualified name that designates a table. If no database
#' identifier is provided, it refers to a table in the current database.
+#' The table name can be fully qualified with catalog name since 3.4.0.
#' @return SparkDataFrame
#' @rdname uncacheTable
#' @examples
@@ -215,13 +278,14 @@ dropTempView <- function(viewName) {
#' Returns a SparkDataFrame containing names of tables in the given database.
#'
#' @param databaseName (optional) name of the database
+#' The database name can be qualified with catalog name since 3.4.0.
#' @return a SparkDataFrame
#' @rdname tables
#' @seealso \link{listTables}
#' @examples
#'\dontrun{
#' sparkR.session()
-#' tables("hive")
+#' tables("spark_catalog.hive")
#' }
#' @name tables
#' @note tables since 1.4.0
@@ -235,12 +299,13 @@ tables <- function(databaseName = NULL) {
#' Returns the names of tables in the given database as an array.
#'
#' @param databaseName (optional) name of the database
+#' The database name can be qualified with catalog name since 3.4.0.
#' @return a list of table names
#' @rdname tableNames
#' @examples
#'\dontrun{
#' sparkR.session()
-#' tableNames("hive")
+#' tableNames("spark_catalog.hive")
#' }
#' @name tableNames
#' @note tableNames since 1.4.0
@@ -293,6 +358,28 @@ setCurrentDatabase <- function(databaseName) {
invisible(handledCallJMethod(catalog, "setCurrentDatabase", databaseName))
}
+#' Checks if the database with the specified name exists.
+#'
+#' Checks if the database with the specified name exists.
+#'
+#' @param databaseName name of the database, allowed to be qualified with catalog name
+#' @rdname databaseExists
+#' @name databaseExists
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' databaseExists("spark_catalog.default")
+#' }
+#' @note since 3.4.0
+databaseExists <- function(databaseName) {
+ sparkSession <- getSparkSession()
+ if (class(databaseName) != "character") {
+ stop("databaseName must be a string.")
+ }
+ catalog <- callJMethod(sparkSession, "catalog")
+ callJMethod(catalog, "databaseExists", databaseName)
+}
+
#' Returns a list of databases available
#'
#' Returns a list of databases available.
@@ -312,12 +399,54 @@ listDatabases <- function() {
dataFrame(callJMethod(callJMethod(catalog, "listDatabases"), "toDF"))
}
+#' Get the database with the specified name
+#'
+#' Get the database with the specified name
+#'
+#' @param databaseName name of the database, allowed to be qualified with catalog name
+#' @return A named list.
+#' @rdname getDatabase
+#' @name getDatabase
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' db <- getDatabase("default")
+#' }
+#' @note since 3.4.0
+getDatabase <- function(databaseName) {
+ sparkSession <- getSparkSession()
+ if (class(databaseName) != "character") {
+ stop("databaseName must be a string.")
+ }
+ catalog <- callJMethod(sparkSession, "catalog")
+ jdb <- handledCallJMethod(catalog, "getDatabase", databaseName)
+
+ ret <- list(name = callJMethod(jdb, "name"))
+ jcata <- callJMethod(jdb, "catalog")
+ if (is.null(jcata)) {
+ ret$catalog <- NA
+ } else {
+ ret$catalog <- jcata
+ }
+
+ jdesc <- callJMethod(jdb, "description")
+ if (is.null(jdesc)) {
+ ret$description <- NA
+ } else {
+ ret$description <- jdesc
+ }
+
+ ret$locationUri <- callJMethod(jdb, "locationUri")
+ ret
+}
+
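
A sketch of databaseExists() and getDatabase() as defined above, assuming an active SparkR session with the default warehouse:

    library(SparkR)
    sparkR.session()

    databaseExists("default")                  # TRUE for the built-in database
    databaseExists("spark_catalog.default")    # the name may be qualified with the catalog

    db <- getDatabase("default")
    db$name                                    # "default"
    db$locationUri                             # warehouse location; catalog/description may be NA
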
#' Returns a list of tables or views in the specified database
#'
#' Returns a list of tables or views in the specified database.
#' This includes all temporary views.
#'
#' @param databaseName (optional) name of the database
+#' The database name can be qualified with catalog name since 3.4.0.
#' @return a SparkDataFrame of the list of tables.
#' @rdname listTables
#' @name listTables
@@ -326,7 +455,7 @@ listDatabases <- function() {
#' \dontrun{
#' sparkR.session()
#' listTables()
-#' listTables("default")
+#' listTables("spark_catalog.default")
#' }
#' @note since 2.2.0
listTables <- function(databaseName = NULL) {
@@ -343,6 +472,78 @@ listTables <- function(databaseName = NULL) {
dataFrame(callJMethod(jdst, "toDF"))
}
+#' Checks if the table with the specified name exists.
+#'
+#' Checks if the table with the specified name exists.
+#'
+#' @param tableName name of the table, allowed to be qualified with catalog name
+#' @rdname tableExists
+#' @name tableExists
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' databaseExists("spark_catalog.default.myTable")
+#' }
+#' @note since 3.4.0
+tableExists <- function(tableName) {
+ sparkSession <- getSparkSession()
+ if (class(tableName) != "character") {
+ stop("tableName must be a string.")
+ }
+ catalog <- callJMethod(sparkSession, "catalog")
+ callJMethod(catalog, "tableExists", tableName)
+}
+
+#' Get the table with the specified name
+#'
+#' Get the table with the specified name
+#'
+#' @param tableName the qualified or unqualified name that designates a table, allowed to be
+#' qualified with catalog name
+#' @return A named list.
+#' @rdname getTable
+#' @name getTable
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' tbl <- getTable("spark_catalog.default.myTable")
+#' }
+#' @note since 3.4.0
+getTable <- function(tableName) {
+ sparkSession <- getSparkSession()
+ if (class(tableName) != "character") {
+ stop("tableName must be a string.")
+ }
+ catalog <- callJMethod(sparkSession, "catalog")
+ jtbl <- handledCallJMethod(catalog, "getTable", tableName)
+
+ ret <- list(name = callJMethod(jtbl, "name"))
+ jcata <- callJMethod(jtbl, "catalog")
+ if (is.null(jcata)) {
+ ret$catalog <- NA
+ } else {
+ ret$catalog <- jcata
+ }
+
+ jns <- callJMethod(jtbl, "namespace")
+ if (is.null(jns)) {
+ ret$namespace <- NA
+ } else {
+ ret$namespace <- jns
+ }
+
+ jdesc <- callJMethod(jtbl, "description")
+ if (is.null(jdesc)) {
+ ret$description <- NA
+ } else {
+ ret$description <- jdesc
+ }
+
+ ret$tableType <- callJMethod(jtbl, "tableType")
+ ret$isTemporary <- callJMethod(jtbl, "isTemporary")
+ ret
+}
+
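
A sketch of tableExists() and getTable() as defined above, assuming an active SparkR session; the "people" view is created here purely for illustration:

    library(SparkR)
    sparkR.session()

    df <- createDataFrame(data.frame(name = c("a", "b"), age = c(1L, 2L)))
    createOrReplaceTempView(df, "people")

    tableExists("people")        # TRUE: temporary views are visible to the catalog
    tbl <- getTable("people")
    tbl$name                     # "people"
    tbl$isTemporary              # TRUE
    # Persistent tables can be addressed with fully qualified names,
    # e.g. tableExists("spark_catalog.default.my_table") for a hypothetical table.
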
#' Returns a list of columns for the given table/view in the specified database
#'
#' Returns a list of columns for the given table/view in the specified database.
@@ -350,6 +551,8 @@ listTables <- function(databaseName = NULL) {
#' @param tableName the qualified or unqualified name that designates a table/view. If no database
#' identifier is provided, it refers to a table/view in the current database.
#' If \code{databaseName} parameter is specified, this must be an unqualified name.
+#' The table name can be qualified with catalog name since 3.4.0, when databaseName
+#' is NULL.
#' @param databaseName (optional) name of the database
#' @return a SparkDataFrame of the list of column descriptions.
#' @rdname listColumns
@@ -357,7 +560,7 @@ listTables <- function(databaseName = NULL) {
#' @examples
#' \dontrun{
#' sparkR.session()
-#' listColumns("mytable")
+#' listColumns("spark_catalog.default.mytable")
#' }
#' @note since 2.2.0
listColumns <- function(tableName, databaseName = NULL) {
@@ -380,13 +583,14 @@ listColumns <- function(tableName, databaseName = NULL) {
#' This includes all temporary functions.
#'
#' @param databaseName (optional) name of the database
+#' The database name can be qualified with catalog name since 3.4.0.
#' @return a SparkDataFrame of the list of function descriptions.
#' @rdname listFunctions
#' @name listFunctions
#' @examples
#' \dontrun{
#' sparkR.session()
-#' listFunctions()
+#' listFunctions("spark_catalog.default")
#' }
#' @note since 2.2.0
listFunctions <- function(databaseName = NULL) {
@@ -403,6 +607,78 @@ listFunctions <- function(databaseName = NULL) {
dataFrame(callJMethod(jdst, "toDF"))
}
+#' Checks if the function with the specified name exists.
+#'
+#' Checks if the function with the specified name exists.
+#'
+#' @param functionName name of the function, allowed to be qualified with catalog name
+#' @rdname functionExists
+#' @name functionExists
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' functionExists("spark_catalog.default.myFunc")
+#' }
+#' @note since 3.4.0
+functionExists <- function(functionName) {
+ sparkSession <- getSparkSession()
+ if (class(functionName) != "character") {
+ stop("functionName must be a string.")
+ }
+ catalog <- callJMethod(sparkSession, "catalog")
+ callJMethod(catalog, "functionExists", functionName)
+}
+
+#' Get the function with the specified name
+#'
+#' Get the function with the specified name
+#'
+#' @param functionName name of the function, allowed to be qualified with catalog name
+#' @return A named list.
+#' @rdname getFunc
+#' @name getFunc
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' func <- getFunc("spark_catalog.default.myFunc")
+#' }
+#' @note since 3.4.0. Uses a different name from the Scala/Python side, to avoid a
+#' signature conflict with the built-in "getFunction".
+getFunc <- function(functionName) {
+ sparkSession <- getSparkSession()
+ if (class(functionName) != "character") {
+ stop("functionName must be a string.")
+ }
+ catalog <- callJMethod(sparkSession, "catalog")
+ jfunc <- handledCallJMethod(catalog, "getFunction", functionName)
+
+ ret <- list(name = callJMethod(jfunc, "name"))
+ jcata <- callJMethod(jfunc, "catalog")
+ if (is.null(jcata)) {
+ ret$catalog <- NA
+ } else {
+ ret$catalog <- jcata
+ }
+
+ jns <- callJMethod(jfunc, "namespace")
+ if (is.null(jns)) {
+ ret$namespace <- NA
+ } else {
+ ret$namespace <- jns
+ }
+
+ jdesc <- callJMethod(jfunc, "description")
+ if (is.null(jdesc)) {
+ ret$description <- NA
+ } else {
+ ret$description <- jdesc
+ }
+
+ ret$className <- callJMethod(jfunc, "className")
+ ret$isTemporary <- callJMethod(jfunc, "isTemporary")
+ ret
+}
+
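
A sketch of functionExists() and getFunc() as defined above, assuming an active SparkR session; "abs" is used because built-in functions are visible to the catalog, and "myFunc" is a hypothetical UDF name:

    library(SparkR)
    sparkR.session()

    functionExists("abs")          # TRUE: built-in functions are visible to the catalog
    functionExists("myFunc")       # FALSE unless such a function has been registered

    fn <- getFunc("abs")
    fn$name                        # "abs"
    fn$className                   # the implementing class
    fn$isTemporary                 # whether it only exists in this session
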
#' Recovers all the partitions in the directory of a table and update the catalog
#'
#' Recovers all the partitions in the directory of a table and update the catalog. The name should
@@ -410,12 +686,13 @@ listFunctions <- function(databaseName = NULL) {
#'
#' @param tableName the qualified or unqualified name that designates a table. If no database
#' identifier is provided, it refers to a table in the current database.
+#' The table name can be fully qualified with catalog name since 3.4.0.
#' @rdname recoverPartitions
#' @name recoverPartitions
#' @examples
#' \dontrun{
#' sparkR.session()
-#' recoverPartitions("myTable")
+#' recoverPartitions("spark_catalog.default.myTable")
#' }
#' @note since 2.2.0
recoverPartitions <- function(tableName) {
@@ -436,12 +713,13 @@ recoverPartitions <- function(tableName) {
#'
#' @param tableName the qualified or unqualified name that designates a table. If no database
#' identifier is provided, it refers to a table in the current database.
+#' The table name can be fully qualified with catalog name since 3.4.0.
#' @rdname refreshTable
#' @name refreshTable
#' @examples
#' \dontrun{
#' sparkR.session()
-#' refreshTable("myTable")
+#' refreshTable("spark_catalog.default.myTable")
#' }
#' @note since 2.2.0
refreshTable <- function(tableName) {
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index f1fd30e144bb6..e4865056f58bc 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -85,7 +85,7 @@ createOperator <- function(op) {
callJMethod(e1@jc, operators[[op]])
}
} else {
- if (class(e2) == "Column") {
+ if (inherits(e2, "Column")) {
e2 <- e2@jc
}
if (op == "^") {
@@ -110,7 +110,7 @@ createColumnFunction2 <- function(name) {
setMethod(name,
signature(x = "Column"),
function(x, data) {
- if (class(data) == "Column") {
+ if (inherits(data, "Column")) {
data <- data@jc
}
jc <- callJMethod(x@jc, name, data)
@@ -306,7 +306,7 @@ setMethod("%in%",
setMethod("otherwise",
signature(x = "Column", value = "ANY"),
function(x, value) {
- value <- if (class(value) == "Column") { value@jc } else { value }
+ value <- if (inherits(value, "Column")) { value@jc } else { value }
jc <- callJMethod(x@jc, "otherwise", value)
column(jc)
})
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index cca6c2c817de9..eea83aa5ab527 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -170,7 +170,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
serializedSlices <- lapply(slices, serialize, connection = NULL)
# The RPC backend cannot handle arguments larger than 2GB (INT_MAX)
- # If serialized data is safely less than that threshold we send it over the PRC channel.
+ # If serialized data is safely less than that threshold we send it over the RPC channel.
# Otherwise, we write it to a file and send the file name
if (objectSize < sizeLimit) {
jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices)
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 1377f0daa7360..00ce630bd18e3 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -258,6 +258,13 @@ NULL
#' into accumulator (the first argument).
#' @param finish an unary \code{function} \code{(Column) -> Column} used to
#' apply final transformation on the accumulated data in \code{array_aggregate}.
+#' @param comparator an optional binary (\code{(Column, Column) -> Column}) \code{function}
+#' which is used to compare the elements of the array.
+#' The comparator will take two
+#' arguments representing two elements of the array. It returns a negative integer,
+#' 0, or a positive integer as the first element is less than, equal to,
+#' or greater than the second element.
+#' If the comparator function returns null, the function will fail and raise an error.
#' @param ... additional argument(s).
#' \itemize{
#' \item \code{to_json}, \code{from_json} and \code{schema_of_json}: this contains
@@ -292,6 +299,7 @@ NULL
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1), shuffle(tmp$v1)))
#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1), array_distinct(tmp$v1)))
#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1)))
+#' head(select(tmp, array_sort(tmp$v1, function(x, y) coalesce(cast(y - x, "integer"), lit(0L)))))
#' head(select(tmp, reverse(tmp$v1), array_remove(tmp$v1, 21)))
#' head(select(tmp, array_transform("v1", function(x) x * 10)))
#' head(select(tmp, array_exists("v1", function(x) x > 120)))
@@ -445,7 +453,7 @@ setMethod("lit", signature("ANY"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions",
"lit",
- if (class(x) == "Column") { x@jc } else { x })
+ if (inherits(x, "Column")) { x@jc } else { x })
column(jc)
})
@@ -966,7 +974,7 @@ setMethod("hash",
#' @details
#' \code{xxhash64}: Calculates the hash code of given columns using the 64-bit
#' variant of the xxHash algorithm, and returns the result as a long
-#' column.
+#' column. The hash computation uses an initial seed of 42.
#'
#' @rdname column_misc_functions
#' @aliases xxhash64 xxhash64,Column-method
@@ -3256,7 +3264,8 @@ setMethod("format_string", signature(format = "character", x = "Column"),
#' tmp <- mutate(df, to_unix = unix_timestamp(df$time),
#' to_unix2 = unix_timestamp(df$time, 'yyyy-MM-dd HH'),
#' from_unix = from_unixtime(unix_timestamp(df$time)),
-#' from_unix2 = from_unixtime(unix_timestamp(df$time), 'yyyy-MM-dd HH:mm'))
+#' from_unix2 = from_unixtime(unix_timestamp(df$time), 'yyyy-MM-dd HH:mm'),
+#' timestamp_from_unix = timestamp_seconds(unix_timestamp(df$time)))
#' head(tmp)}
#' @note from_unixtime since 1.5.0
setMethod("from_unixtime", signature(x = "Column"),
@@ -3586,7 +3595,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"),
setMethod("when", signature(condition = "Column", value = "ANY"),
function(condition, value) {
condition <- condition@jc
- value <- if (class(value) == "Column") { value@jc } else { value }
+ value <- if (inherits(value, "Column")) { value@jc } else { value }
jc <- callJStatic("org.apache.spark.sql.functions", "when", condition, value)
column(jc)
})
@@ -3605,8 +3614,8 @@ setMethod("ifelse",
signature(test = "Column", yes = "ANY", no = "ANY"),
function(test, yes, no) {
test <- test@jc
- yes <- if (class(yes) == "Column") { yes@jc } else { yes }
- no <- if (class(no) == "Column") { no@jc } else { no }
+ yes <- if (inherits(yes, "Column")) { yes@jc } else { yes }
+ no <- if (inherits(no, "Column")) { no@jc } else { no }
jc <- callJMethod(callJStatic("org.apache.spark.sql.functions",
"when",
test, yes),
@@ -4140,9 +4149,16 @@ setMethod("array_repeat",
#' @note array_sort since 2.4.0
setMethod("array_sort",
signature(x = "Column"),
- function(x) {
- jc <- callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc)
- column(jc)
+ function(x, comparator = NULL) {
+ if (is.null(comparator)) {
+ column(callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc))
+ } else {
+ invoke_higher_order_function(
+ "ArraySort",
+ cols = list(x),
+ funs = list(comparator)
+ )
+ }
})
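
To make the comparator contract concrete, a minimal sketch (assuming an active SparkR session) that sorts an array column in descending order; it mirrors the lambda shown in the roxygen example above:

    library(SparkR)
    sparkR.session()

    df <- createDataFrame(data.frame(id = 1L))
    tmp <- mutate(df, v = create_array(lit(3L), lit(1L), lit(2L)))

    # Default ascending sort (no comparator).
    head(select(tmp, array_sort(tmp$v)))

    # Descending sort: the comparator returns a negative, zero or positive integer;
    # coalesce(..., lit(0L)) keeps it from returning NULL.
    head(select(tmp, array_sort(tmp$v, function(x, y) coalesce(cast(y - x, "integer"), lit(0L)))))
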
#' @details
@@ -4854,7 +4870,8 @@ setMethod("current_timestamp",
})
#' @details
-#' \code{timestamp_seconds}: Creates timestamp from the number of seconds since UTC epoch.
+#' \code{timestamp_seconds}: Converts the number of seconds from the Unix epoch
+#' (1970-01-01T00:00:00Z) to a timestamp.
#'
#' @rdname column_datetime_functions
#' @aliases timestamp_seconds timestamp_seconds,Column-method
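
A quick sketch of the clarified timestamp_seconds() behavior, assuming an active SparkR session:

    library(SparkR)
    sparkR.session()

    df <- createDataFrame(data.frame(secs = c(0, 1672531200)))

    # 0 corresponds to 1970-01-01T00:00:00Z and 1672531200 to 2023-01-01T00:00:00Z,
    # rendered in the session time zone.
    head(select(df, df$secs, timestamp_seconds(df$secs)))
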
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 5fe2ec602ecd3..328df50877b70 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -442,7 +442,7 @@ setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
setGeneric("distinct", function(x) { standardGeneric("distinct") })
#' @rdname drop
-setGeneric("drop", function(x, ...) { standardGeneric("drop") })
+setGeneric("drop", function(x, col, ...) { standardGeneric("drop") })
#' @rdname dropDuplicates
setGeneric("dropDuplicates", function(x, ...) { standardGeneric("dropDuplicates") })
@@ -670,6 +670,16 @@ setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSp
#' @rdname broadcast
setGeneric("broadcast", function(x) { standardGeneric("broadcast") })
+#' @rdname unpivot
+setGeneric("unpivot", function(x, ids, values, variableColumnName, valueColumnName) {
+ standardGeneric("unpivot")
+})
+
+#' @rdname melt
+setGeneric("melt", function(x, ids, values, variableColumnName, valueColumnName) {
+ standardGeneric("melt")
+})
+
###################### Column Methods ##########################
#' @rdname columnfunctions
@@ -840,7 +850,7 @@ setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat")
#' @rdname column_collection_functions
#' @name NULL
-setGeneric("array_sort", function(x) { standardGeneric("array_sort") })
+setGeneric("array_sort", function(x, ...) { standardGeneric("array_sort") })
#' @rdname column_ml_functions
#' @name NULL
diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R
index bbb9188cd083f..971de6010eb8a 100644
--- a/R/pkg/R/install.R
+++ b/R/pkg/R/install.R
@@ -29,19 +29,18 @@
#' \code{mirrorUrl} specifies the remote path to a Spark folder. It is followed by a subfolder
#' named after the Spark version (that corresponds to SparkR), and then the tar filename.
#' The filename is composed of four parts, i.e. [Spark version]-bin-[Hadoop version].tgz.
-#' For example, the full path for a Spark 2.0.0 package for Hadoop 2.7 from
-#' \code{http://apache.osuosl.org} has path:
-#' \code{http://apache.osuosl.org/spark/spark-2.0.0/spark-2.0.0-bin-hadoop2.7.tgz}.
+#' For example, the full path for a Spark 3.3.1 package from
+#' \code{https://archive.apache.org} has path:
+#' \code{http://archive.apache.org/dist/spark/spark-3.3.1/spark-3.3.1-bin-hadoop3.tgz}.
#' For \code{hadoopVersion = "without"}, [Hadoop version] in the filename is then
#' \code{without-hadoop}.
#'
-#' @param hadoopVersion Version of Hadoop to install. Default is \code{"2.7"}. It can take other
-#' version number in the format of "x.y" where x and y are integer.
+#' @param hadoopVersion Version of Hadoop to install. Default is \code{"3"}.
#' If \code{hadoopVersion = "without"}, "Hadoop free" build is installed.
#' See
#' \href{https://spark.apache.org/docs/latest/hadoop-provided.html}{
#' "Hadoop Free" Build} for more information.
-#' Other patched version names can also be used, e.g. \code{"cdh4"}
+#' Other patched version names can also be used.
#' @param mirrorUrl base URL of the repositories to use. The directory layout should follow
#' \href{https://www.apache.org/dyn/closer.lua/spark/}{Apache mirrors}.
#' @param localDir a local directory where Spark is installed. The directory contains
@@ -65,7 +64,7 @@
#' @note install.spark since 2.1.0
#' @seealso See available Hadoop versions:
#' \href{https://spark.apache.org/downloads.html}{Apache Spark}
-install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL,
+install.spark <- function(hadoopVersion = "3", mirrorUrl = NULL,
localDir = NULL, overwrite = FALSE) {
sparkHome <- Sys.getenv("SPARK_HOME")
if (isSparkRShell()) {
@@ -251,7 +250,7 @@ defaultMirrorUrl <- function() {
hadoopVersionName <- function(hadoopVersion) {
if (hadoopVersion == "without") {
"without-hadoop"
- } else if (grepl("^[0-9]+\\.[0-9]+$", hadoopVersion, perl = TRUE)) {
+ } else if (grepl("^[0-9]+$", hadoopVersion, perl = TRUE)) {
paste0("hadoop", hadoopVersion)
} else {
hadoopVersion
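For reference, a short sketch of how install.spark() would be invoked under the updated defaults; this is an illustration of the documented layout, not part of the patch:

    install.spark()                            # resolves e.g. spark-<version>-bin-hadoop3.tgz
    install.spark(hadoopVersion = "without")   # "Hadoop free" build
    # mirrorUrl points at the Spark folder; the version subfolder and tarball name are appended
    install.spark(mirrorUrl = "https://archive.apache.org/dist/spark")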
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index 093467ecf7d28..7204f8bb7dff4 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -322,7 +322,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
}
if (!is.null(lowerBoundsOnCoefficients)) {
- if (class(lowerBoundsOnCoefficients) != "matrix") {
+ if (!is.matrix(lowerBoundsOnCoefficients)) {
stop("lowerBoundsOnCoefficients must be a matrix.")
}
row <- nrow(lowerBoundsOnCoefficients)
@@ -331,7 +331,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
}
if (!is.null(upperBoundsOnCoefficients)) {
- if (class(upperBoundsOnCoefficients) != "matrix") {
+ if (!is.matrix(upperBoundsOnCoefficients)) {
stop("upperBoundsOnCoefficients must be a matrix.")
}
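The is.matrix() switch matters because R 4.0.0 gave matrices the implicit class c("matrix", "array"), so the old scalar comparison no longer behaves as intended. A quick illustration in plain R (no Spark needed):

    m <- matrix(1:4, nrow = 2)
    class(m)                  # c("matrix", "array") on R >= 4.0.0
    class(m) != "matrix"      # length-2 logical, which `if ()` cannot use (an error since R 4.2.0)
    is.matrix(m)              # single TRUE/FALSE, safe inside `if ()`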
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 7760d9be16f0b..61e174de9ac56 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -58,7 +58,12 @@ writeObject <- function(con, object, writeType = TRUE) {
# Checking types is needed here, since 'is.na' only handles atomic vectors,
# lists and pairlists
if (type %in% c("integer", "character", "logical", "double", "numeric")) {
- if (is.na(object)) {
+ if (is.na(object[[1]])) {
+ # Uses the first element for now to keep the behavior the same as in R
+ # before 4.2.0. This is wrong because we should differentiate c(NA) from a
+ # single NA, as the former means array(null) and the latter means null
+ # in Spark SQL. However, distinguishing the two requires a non-trivial
+ # comparison in R. We should ideally fix this.
object <- NULL
type <- "NULL"
}
@@ -203,7 +208,11 @@ writeEnv <- function(con, env) {
}
writeDate <- function(con, date) {
- writeString(con, as.character(date))
+ if (is.na(date)) {
+ writeString(con, "NA")
+ } else {
+ writeString(con, as.character(date))
+ }
}
writeTime <- function(con, time) {
@@ -226,7 +235,7 @@ writeSerializeInArrow <- function(conn, df) {
# There looks no way to send each batch in streaming format via socket
# connection. See ARROW-4512.
# So, it writes the whole Arrow streaming-formatted binary at once for now.
- writeRaw(conn, arrow::write_arrow(df, raw()))
+ writeRaw(conn, arrow::write_to_raw(df))
} else {
stop("'arrow' package should be installed.")
}
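The comment in writeObject refers to how `if (is.na(object))` broke once R 4.2.0 made length > 1 conditions an error; checking only the first element restores the old behavior at the cost of conflating c(NA) with a bare NA. Roughly:

    x <- c(NA, 1L)
    is.na(x)         # c(TRUE, FALSE): a length-2 condition errors in `if ()` on R >= 4.2.0
    is.na(x[[1]])    # TRUE: a single logical, matching what `if ()` silently did before 4.2.0

The Arrow change, by contrast, is a plain API rename: arrow::write_arrow() was removed upstream, and write_to_raw() returns the same stream-formatted binary as a raw vector.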
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index f18a6c7e25f1b..e2ab57471773c 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -40,8 +40,15 @@ sparkR.session.stop <- function() {
env <- .sparkREnv
if (exists(".sparkRCon", envir = env)) {
if (exists(".sparkRjsc", envir = env)) {
- sc <- get(".sparkRjsc", envir = env)
- callJMethod(sc, "stop")
+ # Wrap every use of the connection in tryCatch in case the
+ # connection has timed out; see also SPARK-42186.
+ tryCatch({
+ sc <- get(".sparkRjsc", envir = env)
+ callJMethod(sc, "stop")
+ },
+ error = function(err) {
+ warning(err)
+ })
rm(".sparkRjsc", envir = env)
if (exists(".sparkRsession", envir = env)) {
@@ -56,20 +63,35 @@ sparkR.session.stop <- function() {
}
if (exists(".backendLaunched", envir = env)) {
- callJStatic("SparkRHandler", "stopBackend")
+ tryCatch({
+ callJStatic("SparkRHandler", "stopBackend")
+ },
+ error = function(err) {
+ warning(err)
+ })
}
# Also close the connection and remove it from our env
- conn <- get(".sparkRCon", envir = env)
- close(conn)
+ tryCatch({
+ conn <- get(".sparkRCon", envir = env)
+ close(conn)
+ },
+ error = function(err) {
+ warning(err)
+ })
rm(".sparkRCon", envir = env)
rm(".scStartTime", envir = env)
}
if (exists(".monitorConn", envir = env)) {
- conn <- get(".monitorConn", envir = env)
- close(conn)
+ tryCatch({
+ conn <- get(".monitorConn", envir = env)
+ close(conn)
+ },
+ error = function(err) {
+ warning(err)
+ })
rm(".monitorConn", envir = env)
}
diff --git a/R/pkg/pkgdown/_pkgdown_template.yml b/R/pkg/pkgdown/_pkgdown_template.yml
index eeb676befbc8b..e6b485d489844 100644
--- a/R/pkg/pkgdown/_pkgdown_template.yml
+++ b/R/pkg/pkgdown/_pkgdown_template.yml
@@ -117,6 +117,7 @@ reference:
- unionAll
- unionByName
- unpersist
+ - unpivot
- with
- withColumn
@@ -261,9 +262,16 @@ reference:
- title: "SQL Catalog"
- contents:
+ - currentCatalog
- currentDatabase
+ - databaseExists
- dropTempTable
- dropTempView
+ - functionExists
+ - getDatabase
+ - getFunc
+ - getTable
+ - listCatalogs
- listColumns
- listDatabases
- listFunctions
@@ -271,6 +279,9 @@ reference:
- refreshByPath
- refreshTable
- recoverPartitions
+ - setCurrentCatalog
+ - setCurrentDatabase
+ - tableExists
- tableNames
- tables
- uncacheTable
@@ -283,7 +294,6 @@ reference:
- getLocalProperty
- install.spark
- setCheckpointDir
- - setCurrentDatabase
- setJobDescription
- setJobGroup
- setLocalProperty
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index df1094bacef64..b0c56f1c15d06 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -154,7 +154,7 @@ test_that("structType and structField", {
expect_is(testSchema$fields()[[2]], "structField")
expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType")
- expect_error(structType("A stri"), "DataType stri is not supported.")
+ expect_error(structType("A stri"), ".*Unsupported data type \"STRI\".*")
})
test_that("structField type strings", {
@@ -495,7 +495,7 @@ test_that("SPARK-17902: collect() with stringsAsFactors enabled", {
expect_equal(iris$Species, df$Species)
})
-test_that("SPARK-17811: can create DataFrame containing NA as date and time", {
+test_that("SPARK-17811, SPARK-18011: can create DataFrame containing NA as date and time", {
df <- data.frame(
id = 1:2,
time = c(as.POSIXlt("2016-01-10"), NA),
@@ -622,7 +622,7 @@ test_that("read/write json files", {
# Test errorifexists
expect_error(write.df(df, jsonPath2, "json", mode = "errorifexists"),
- "analysis error - path file:.*already exists")
+ "Error in save : analysis error - \\[PATH_ALREADY_EXISTS\\].*")
# Test write.json
jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json")
@@ -663,7 +663,7 @@ test_that("test tableNames and tables", {
expect_equal(count(tables), count + 1)
expect_equal(count(tables()), count(tables))
expect_true("tableName" %in% colnames(tables()))
- expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables())))
+ expect_true(all(c("tableName", "namespace", "isTemporary") %in% colnames(tables())))
suppressWarnings(registerTempTable(df, "table2"))
tables <- listTables()
@@ -673,6 +673,22 @@ test_that("test tableNames and tables", {
tables <- listTables()
expect_equal(count(tables), count + 0)
+
+ count2 <- count(listTables())
+ schema <- structType(structField("name", "string"), structField("age", "integer"),
+ structField("height", "float"))
+ createTable("people", source = "json", schema = schema)
+
+ expect_equal(length(tableNames()), count2 + 1)
+ expect_equal(length(tableNames("default")), count2 + 1)
+ expect_equal(length(tableNames("spark_catalog.default")), count2 + 1)
+
+ tables <- listTables()
+ expect_equal(count(tables), count2 + 1)
+ expect_equal(count(tables()), count(tables))
+ expect_equal(count(tables("default")), count2 + 1)
+ expect_equal(count(tables("spark_catalog.default")), count2 + 1)
+ sql("DROP TABLE IF EXISTS people")
})
test_that(
@@ -696,16 +712,27 @@ test_that(
expect_true(dropTempView("dfView"))
})
-test_that("test cache, uncache and clearCache", {
- df <- read.json(jsonPath)
- createOrReplaceTempView(df, "table1")
- cacheTable("table1")
- uncacheTable("table1")
+test_that("test tableExists, cache, uncache and clearCache", {
+ schema <- structType(structField("name", "string"), structField("age", "integer"),
+ structField("height", "float"))
+ createTable("table1", source = "json", schema = schema)
+
+ cacheTable("default.table1")
+ uncacheTable("spark_catalog.default.table1")
clearCache()
- expect_true(dropTempView("table1"))
expect_error(uncacheTable("zxwtyswklpf"),
- "Error in uncacheTable : analysis error - Table or view not found: zxwtyswklpf")
+ "[TABLE_OR_VIEW_NOT_FOUND]*`zxwtyswklpf`*")
+
+ expect_true(tableExists("table1"))
+ expect_true(tableExists("default.table1"))
+ expect_true(tableExists("spark_catalog.default.table1"))
+
+ sql("DROP TABLE IF EXISTS spark_catalog.default.table1")
+
+ expect_false(tableExists("table1"))
+ expect_false(tableExists("default.table1"))
+ expect_false(tableExists("spark_catalog.default.table1"))
})
test_that("insertInto() on a registered table", {
@@ -1264,6 +1291,15 @@ test_that("drop column", {
df1 <- drop(df, df$age)
expect_equal(columns(df1), c("name", "age2"))
+ df1 <- drop(df, df$age, df$name)
+ expect_equal(columns(df1), c("age2"))
+
+ df1 <- drop(df, df$age, column("random"))
+ expect_equal(columns(df1), c("name", "age2"))
+
+ df1 <- drop(df, df$age, "random")
+ expect_equal(columns(df1), c("name", "age2"))
+
df$age2 <- NULL
expect_equal(columns(df), c("name", "age"))
df$age3 <- NULL
@@ -1342,7 +1378,7 @@ test_that("test HiveContext", {
schema <- structType(structField("name", "string"), structField("age", "integer"),
structField("height", "float"))
- createTable("people", source = "json", schema = schema)
+ createTable("spark_catalog.default.people", source = "json", schema = schema)
df <- read.df(jsonPathNa, "json", schema)
insertInto(df, "people")
expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16))
@@ -1568,6 +1604,16 @@ test_that("column functions", {
result <- collect(select(df, array_sort(df[[1]])))[[1]]
expect_equal(result, list(list(1L, 2L, 3L, NA), list(4L, 5L, 6L, NA, NA)))
+ result <- collect(select(
+ df,
+ array_sort(
+ df[[1]],
+ function(x, y) otherwise(
+ when(isNull(x), 1L), otherwise(when(isNull(y), -1L), cast(y - x, "integer"))
+ )
+ )
+ ))[[1]]
+ expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA)))
result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]]
expect_equal(result, list(list(3L, 2L, 1L, NA), list(6L, 5L, 4L, NA, NA)))
@@ -2967,6 +3013,32 @@ test_that("mutate(), transform(), rename() and names()", {
expect_match(tail(columns(newDF), 1L), "234567890", fixed = TRUE)
})
+test_that("unpivot / melt", {
+ df <- createDataFrame(data.frame(
+ id = 1:3, x = c(1, 3, 5), y = c(2, 4, 6), z = c(-1, 0, 1)
+ ))
+
+ result <- unpivot(df, "id", c("x", "y"), "var", "val")
+ expect_s4_class(result, "SparkDataFrame")
+ expect_equal(columns(result), c("id", "var", "val"))
+ expect_equal(count(distinct(select(result, "var"))), 2)
+
+ result <- unpivot(df, "id", NULL, "variable", "value")
+ expect_s4_class(result, "SparkDataFrame")
+ expect_equal(columns(result), c("id", "variable", "value"))
+ expect_equal(count(distinct(select(result, "variable"))), 3)
+
+ result <- melt(df, "id", c("x", "y"), "key", "value")
+ expect_s4_class(result, "SparkDataFrame")
+ expect_equal(columns(result), c("id", "key", "value"))
+ expect_equal(count(distinct(select(result, "key"))), 2)
+
+ result <- melt(df, "id", NULL, "key", "val")
+ expect_s4_class(result, "SparkDataFrame")
+ expect_equal(columns(result), c("id", "key", "val"))
+ expect_equal(count(distinct(select(result, "key"))), 3)
+})
+
test_that("read/write ORC files", {
setHiveContext(sc)
df <- read.df(jsonPath, "json")
@@ -3321,8 +3393,8 @@ test_that("approxQuantile() on a DataFrame", {
test_that("SQL error message is returned from JVM", {
retError <- tryCatch(sql("select * from blah"), error = function(e) e)
- expect_equal(grepl("Table or view not found", retError), TRUE)
- expect_equal(grepl("blah", retError), TRUE)
+ expect_equal(grepl("[TABLE_OR_VIEW_NOT_FOUND]", retError), TRUE)
+ expect_equal(grepl("`blah`", retError), TRUE)
})
irisDF <- suppressWarnings(createDataFrame(iris))
@@ -3411,6 +3483,8 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", {
"Length of type vector should match the number of columns for SparkDataFrame")
expect_error(coltypes(df) <- c("environment", "list"),
"Only atomic type is supported for column types")
+
+ dropTempView("dfView")
})
test_that("Method str()", {
@@ -3450,6 +3524,8 @@ test_that("Method str()", {
# Test utils:::str
expect_equal(capture.output(utils:::str(iris)), capture.output(str(iris)))
+
+ dropTempView("irisView")
})
test_that("Histogram", {
@@ -3911,15 +3987,16 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume
# It makes sure that we can omit path argument in write.df API and then it calls
# DataFrameWriter.save() without path.
expect_error(write.df(df, source = "csv"),
- "Error in save : illegal argument - Expected exactly one path to be specified")
+ paste("Error in save : org.apache.spark.SparkIllegalArgumentException:",
+ "Expected exactly one path to be specified"))
expect_error(write.json(df, jsonPath),
- "Error in json : analysis error - path file:.*already exists")
+ "Error in json : analysis error - \\[PATH_ALREADY_EXISTS\\].*")
expect_error(write.text(df, jsonPath),
- "Error in text : analysis error - path file:.*already exists")
+ "Error in text : analysis error - \\[PATH_ALREADY_EXISTS\\].*")
expect_error(write.orc(df, jsonPath),
- "Error in orc : analysis error - path file:.*already exists")
+ "Error in orc : analysis error - \\[PATH_ALREADY_EXISTS\\].*")
expect_error(write.parquet(df, jsonPath),
- "Error in parquet : analysis error - path file:.*already exists")
+ "Error in parquet : analysis error - \\[PATH_ALREADY_EXISTS\\].*")
expect_error(write.parquet(df, jsonPath, mode = 123), "mode should be character or omitted.")
# Arguments checking in R side.
@@ -3937,14 +4014,17 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume
# It makes sure that we can omit path argument in read.df API and then it calls
# DataFrameWriter.load() without path.
expect_error(read.df(source = "json"),
- paste("Error in load : analysis error - Unable to infer schema for JSON.",
- "It must be specified manually"))
- expect_error(read.df("arbitrary_path"), "Error in load : analysis error - Path does not exist")
- expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist")
- expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist")
- expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist")
+ "Error in load : analysis error - \\[UNABLE_TO_INFER_SCHEMA\\].*")
+ expect_error(read.df("arbitrary_path"),
+ "Error in load : analysis error - \\[PATH_NOT_FOUND\\].*")
+ expect_error(read.json("arbitrary_path"),
+ "Error in json : analysis error - \\[PATH_NOT_FOUND\\].*")
+ expect_error(read.text("arbitrary_path"),
+ "Error in text : analysis error - \\[PATH_NOT_FOUND\\].*")
+ expect_error(read.orc("arbitrary_path"),
+ "Error in orc : analysis error - \\[PATH_NOT_FOUND\\].*")
expect_error(read.parquet("arbitrary_path"),
- "Error in parquet : analysis error - Path does not exist")
+ "Error in parquet : analysis error - \\[PATH_NOT_FOUND\\].*")
# Arguments checking in R side.
expect_error(read.df(path = c(3)),
@@ -3963,14 +4043,14 @@ test_that("Specify a schema by using a DDL-formatted string when reading", {
expect_is(df1, "SparkDataFrame")
expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double")))
- expect_error(read.df(jsonPath, "json", "name stri"), "DataType stri is not supported.")
+ expect_error(read.df(jsonPath, "json", "name stri"), ".*Unsupported data type \"STRI\".*")
# Test loadDF with a user defined schema in a DDL-formatted string.
df2 <- loadDF(jsonPath, "json", "name STRING, age DOUBLE")
expect_is(df2, "SparkDataFrame")
expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double")))
- expect_error(loadDF(jsonPath, "json", "name stri"), "DataType stri is not supported.")
+ expect_error(loadDF(jsonPath, "json", "name stri"), ".*Unsupported data type \"STRI\".*")
})
test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", {
@@ -4011,22 +4091,45 @@ test_that("Collect on DataFrame when NAs exists at the top of a timestamp column
expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt"))
})
-test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases", {
+test_that("catalog APIs, listCatalogs, setCurrentCatalog, currentCatalog", {
+ expect_equal(currentCatalog(), "spark_catalog")
+ expect_error(setCurrentCatalog("spark_catalog"), NA)
+ expect_error(setCurrentCatalog("zxwtyswklpf"),
+ paste0("Error in setCurrentCatalog : ",
+ "org.apache.spark.sql.connector.catalog.CatalogNotFoundException: ",
+ "Catalog 'zxwtyswklpf' plugin class not found: ",
+ "spark.sql.catalog.zxwtyswklpf is not defined"))
+ catalogs <- collect(listCatalogs())
+})
+
+test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases, getDatabase", {
expect_equal(currentDatabase(), "default")
expect_error(setCurrentDatabase("default"), NA)
expect_error(setCurrentDatabase("zxwtyswklpf"),
- paste0("Error in setCurrentDatabase : analysis error - Database ",
- "'zxwtyswklpf' does not exist"))
+ "[SCHEMA_NOT_FOUND]*`zxwtyswklpf`*")
+
+ expect_true(databaseExists("default"))
+ expect_true(databaseExists("spark_catalog.default"))
+ expect_false(databaseExists("some_db"))
+ expect_false(databaseExists("spark_catalog.some_db"))
+
dbs <- collect(listDatabases())
- expect_equal(names(dbs), c("name", "description", "locationUri"))
+ expect_equal(names(dbs), c("name", "catalog", "description", "locationUri"))
expect_equal(which(dbs[, 1] == "default"), 1)
+
+ db <- getDatabase("spark_catalog.default")
+ expect_equal(db$name, "default")
+ expect_equal(db$catalog, "spark_catalog")
})
-test_that("catalog APIs, listTables, listColumns, listFunctions", {
+test_that("catalog APIs, listTables, getTable, listColumns, listFunctions, functionExists", {
tb <- listTables()
count <- count(tables())
+ expect_equal(nrow(listTables("default")), count)
+ expect_equal(nrow(listTables("spark_catalog.default")), count)
expect_equal(nrow(tb), count)
- expect_equal(colnames(tb), c("name", "database", "description", "tableType", "isTemporary"))
+ expect_equal(colnames(tb),
+ c("name", "catalog", "namespace", "description", "tableType", "isTemporary"))
createOrReplaceTempView(as.DataFrame(cars), "cars")
@@ -4035,7 +4138,7 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", {
tbs <- collect(tb)
expect_true(nrow(tbs[tbs$name == "cars", ]) > 0)
expect_error(listTables("bar"),
- "Error in listTables : no such database - Database 'bar' not found")
+ "[SCHEMA_NOT_FOUND]*`bar`*")
c <- listColumns("cars")
expect_equal(nrow(c), 2)
@@ -4043,18 +4146,48 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", {
c("name", "description", "dataType", "nullable", "isPartition", "isBucket"))
expect_equal(collect(c)[[1]][[1]], "speed")
expect_error(listColumns("zxwtyswklpf", "default"),
- paste("Error in listColumns : analysis error - Table",
- "'zxwtyswklpf' does not exist in database 'default'"))
+ "[TABLE_OR_VIEW_NOT_FOUND]*`spark_catalog`.`default`.`zxwtyswklpf`*")
f <- listFunctions()
expect_true(nrow(f) >= 200) # 250
expect_equal(colnames(f),
- c("name", "database", "description", "className", "isTemporary"))
- expect_equal(take(orderBy(f, "className"), 1)$className,
+ c("name", "catalog", "namespace", "description", "className", "isTemporary"))
+ expect_equal(take(orderBy(filter(f, "className IS NOT NULL"), "className"), 1)$className,
"org.apache.spark.sql.catalyst.expressions.Abs")
expect_error(listFunctions("zxwtyswklpf_db"),
- paste("Error in listFunctions : analysis error - Database",
- "'zxwtyswklpf_db' does not exist"))
+ "[SCHEMA_NOT_FOUND]*`zxwtyswklpf_db`*")
+
+ expect_true(functionExists("abs"))
+ expect_false(functionExists("aabbss"))
+
+ func0 <- getFunc("abs")
+ expect_equal(func0$name, "abs")
+ expect_equal(func0$className, "org.apache.spark.sql.catalyst.expressions.Abs")
+ expect_true(func0$isTemporary)
+
+ sql("CREATE FUNCTION func1 AS 'org.apache.spark.sql.catalyst.expressions.Add'")
+
+ func1 <- getFunc("spark_catalog.default.func1")
+ expect_equal(func1$name, "func1")
+ expect_equal(func1$catalog, "spark_catalog")
+ expect_equal(length(func1$namespace), 1)
+ expect_equal(func1$namespace[[1]], "default")
+ expect_equal(func1$className, "org.apache.spark.sql.catalyst.expressions.Add")
+ expect_false(func1$isTemporary)
+
+ expect_true(functionExists("func1"))
+ expect_true(functionExists("default.func1"))
+ expect_true(functionExists("spark_catalog.default.func1"))
+
+ expect_false(functionExists("func2"))
+ expect_false(functionExists("default.func2"))
+ expect_false(functionExists("spark_catalog.default.func2"))
+
+ sql("DROP FUNCTION func1")
+
+ expect_false(functionExists("func1"))
+ expect_false(functionExists("default.func1"))
+ expect_false(functionExists("spark_catalog.default.func1"))
# recoverPartitions does not work with temporary view
expect_error(recoverPartitions("cars"),
@@ -4063,7 +4196,26 @@ test_that("catalog APIs, listTables, listColumns, listFunctions", {
expect_error(refreshTable("cars"), NA)
expect_error(refreshByPath("/"), NA)
+ view <- getTable("cars")
+ expect_equal(view$name, "cars")
+ expect_equal(view$tableType, "TEMPORARY")
+ expect_true(view$isTemporary)
+
dropTempView("cars")
+
+ schema <- structType(structField("name", "string"), structField("age", "integer"),
+ structField("height", "float"))
+ createTable("default.people", source = "json", schema = schema)
+
+ tbl <- getTable("spark_catalog.default.people")
+ expect_equal(tbl$name, "people")
+ expect_equal(tbl$catalog, "spark_catalog")
+ expect_equal(length(tbl$namespace), 1)
+ expect_equal(tbl$namespace[[1]], "default")
+ expect_equal(tbl$tableType, "MANAGED")
+ expect_false(tbl$isTemporary)
+
+ sql("DROP TABLE IF EXISTS people")
})
test_that("assert_true, raise_error", {
@@ -4084,6 +4236,54 @@ test_that("assert_true, raise_error", {
expect_error(collect(select(filtered, raise_error(filtered$name))), "Justin")
})
+test_that("SPARK-41937: check class column for multi-class object works", {
+ .originalTimeZone <- Sys.getenv("TZ")
+ Sys.setenv(TZ = "")
+ temp_time <- as.POSIXlt("2015-03-11 12:13:04.043", tz = "")
+ sdf <- createDataFrame(
+ data.frame(x = temp_time + c(-1, 1, -1, 1, -1)),
+ schema = structType("x timestamp")
+ )
+ expect_warning(collect(filter(sdf, column("x") > temp_time)), NA)
+ expect_equal(collect(filter(sdf, column("x") > temp_time)), data.frame(x = temp_time + c(1, 1)))
+ expect_warning(collect(filter(sdf, contains(column("x"), temp_time + 5))), NA)
+ expect_warning(
+ collect(
+ mutate(
+ sdf,
+ newcol = otherwise(when(column("x") > lit(temp_time), temp_time), temp_time + 1)
+ )
+ ),
+ NA
+ )
+ expect_equal(
+ collect(
+ mutate(
+ sdf,
+ newcol = otherwise(when(column("x") > lit(temp_time), temp_time), temp_time + 1)
+ )
+ ),
+ data.frame(x = temp_time + c(-1, 1, -1, 1, -1), newcol = temp_time + c(1, 0, 1, 0, 1))
+ )
+ expect_error(
+ collect(fillna(sdf, temp_time)),
+ "value should be an integer, numeric, character or named list"
+ )
+ expect_error(
+ collect(fillna(sdf, list(x = temp_time))),
+ "value should be an integer, numeric or character"
+ )
+ expect_warning(
+ collect(mutate(sdf, x2 = ifelse(column("x") > temp_time, temp_time + 5, temp_time - 5))),
+ NA
+ )
+ expect_equal(
+ collect(mutate(sdf, x2 = ifelse(column("x") > temp_time, temp_time + 5, temp_time - 5))),
+ data.frame(x = temp_time + c(-1, 1, -1, 1, -1), x2 = temp_time + c(-5, 5, -5, 5, -5))
+ )
+ Sys.setenv(TZ = .originalTimeZone)
+})
+
compare_list <- function(list1, list2) {
# get testthat to show the diff by first making the 2 lists equal in length
expect_equal(length(list1), length(list2))
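The new catalog tests above map directly onto the SparkR functions added in this change. A condensed sketch of the API surface they exercise, assuming a running session and a `default.people` table like the one the tests create:

    currentCatalog()                          # "spark_catalog"
    setCurrentCatalog("spark_catalog")
    collect(listCatalogs())

    databaseExists("spark_catalog.default")   # TRUE
    db <- getDatabase("spark_catalog.default")

    tableExists("default.people")             # TRUE
    tbl <- getTable("spark_catalog.default.people")

    functionExists("abs")                     # TRUE
    fn <- getFunc("abs")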
diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R
index 6f0d2aefee886..8804471e640cf 100644
--- a/R/pkg/tests/fulltests/test_streaming.R
+++ b/R/pkg/tests/fulltests/test_streaming.R
@@ -130,7 +130,7 @@ test_that("Specify a schema by using a DDL-formatted string when reading", {
stopQuery(q)
expect_error(read.stream(path = parquetPath, schema = "name stri"),
- "DataType stri is not supported.")
+ ".*Unsupported data type \"STRI\".*")
unlink(parquetPath)
})
@@ -140,8 +140,7 @@ test_that("Non-streaming DataFrame", {
expect_false(isStreaming(c))
expect_error(write.stream(c, "memory", queryName = "people", outputMode = "complete"),
- paste0(".*(writeStream : analysis error - 'writeStream' can be called only on ",
- "streaming Dataset/DataFrame).*"))
+ paste0("Error in writeStream : analysis error - \\[WRITE_STREAM_NOT_ALLOWED\\].*"))
})
test_that("Unsupported operation", {
diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R
index 35f9c9e7bb31e..4d263e5d76509 100644
--- a/R/pkg/tests/fulltests/test_utils.R
+++ b/R/pkg/tests/fulltests/test_utils.R
@@ -190,7 +190,7 @@ test_that("captureJVMException", {
error = function(e) {
captureJVMException(e, method)
}),
- "parse error - .*DataType unknown.*not supported.")
+ ".*Unsupported data type \"UNKNOWN\".*")
})
test_that("hashCode", {
diff --git a/R/run-tests.sh b/R/run-tests.sh
index 99b7438a80097..90a60eda03871 100755
--- a/R/run-tests.sh
+++ b/R/run-tests.sh
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
@@ -23,16 +23,16 @@ FAILED=0
LOGFILE=$FWDIR/unit-tests.out
rm -f $LOGFILE
-SPARK_AVRO_JAR_PATH=$(find $FWDIR/../external/avro/ -name "spark-avro*jar" -print | egrep -v "tests.jar|test-sources.jar|sources.jar|javadoc.jar")
+SPARK_AVRO_JAR_PATH=$(find $FWDIR/../connector/avro/ -name "spark-avro*jar" -print | egrep -v "tests.jar|test-sources.jar|sources.jar|javadoc.jar")
if [[ $(echo $SPARK_AVRO_JAR_PATH | wc -l) -eq 1 ]]; then
SPARK_JARS=$SPARK_AVRO_JAR_PATH
fi
if [ -z "$SPARK_JARS" ]; then
- SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
+ SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
else
- SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --jars $SPARK_JARS --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
+ SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --jars $SPARK_JARS --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
fi
FAILED=$((PIPESTATUS[0]||$FAILED))
diff --git a/README.md b/README.md
index dbc0f2ba87ead..310df41f4654b 100644
--- a/README.md
+++ b/README.md
@@ -9,9 +9,10 @@ and Structured Streaming for stream processing.
-[![GitHub Action Build](https://github.com/apache/spark/actions/workflows/build_and_test.yml/badge.svg?branch=master&event=push)](https://github.com/apache/spark/actions/workflows/build_and_test.yml?query=branch%3Amaster+event%3Apush)
+[![GitHub Actions Build](https://github.com/apache/spark/actions/workflows/build_main.yml/badge.svg)](https://github.com/apache/spark/actions/workflows/build_main.yml)
[![AppVeyor Build](https://img.shields.io/appveyor/ci/ApacheSoftwareFoundation/spark/master.svg?style=plastic&logo=appveyor)](https://ci.appveyor.com/project/ApacheSoftwareFoundation/spark)
[![PySpark Coverage](https://codecov.io/gh/apache/spark/branch/master/graph/badge.svg)](https://codecov.io/gh/apache/spark)
+[![PyPI Downloads](https://static.pepy.tech/personalized-badge/pyspark?period=month&units=international_system&left_color=black&right_color=orange&left_text=PyPI%20downloads)](https://pypi.org/project/pyspark/)
## Online Documentation
diff --git a/appveyor.yml b/appveyor.yml
index 53ef8527c6555..fdb247d5d4375 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -28,6 +28,7 @@ only_commits:
files:
- appveyor.yml
- dev/appveyor-install-dependencies.ps1
+ - build/spark-build-info.ps1
- R/
- sql/core/src/main/scala/org/apache/spark/sql/api/r/
- core/src/main/scala/org/apache/spark/api/r/
@@ -50,10 +51,12 @@ build_script:
# See SPARK-28759.
# Ideally we should check the tests related to Hive in SparkR as well (SPARK-31745).
- cmd: set SBT_MAVEN_PROFILES=-Psparkr
- - cmd: set SBT_OPTS=-Djna.nosys=true -Dfile.encoding=UTF-8 -Xms4096m -Xms4096m -XX:ReservedCodeCacheSize=128m
+ - cmd: set SBT_OPTS=-Djna.nosys=true -Dfile.encoding=UTF-8 -XX:ReservedCodeCacheSize=128m
+ - cmd: set JAVA_OPTS=-Xmx4096m -Xms4096m
- cmd: sbt package
- cmd: set SBT_MAVEN_PROFILES=
- cmd: set SBT_OPTS=
+ - cmd: set JAVA_OPTS=
environment:
NOT_CRAN: true
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 32126a5e13820..b09ffdad3ff3e 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent_2.12
- 3.3.1
+ 3.4.1../pom.xml
@@ -152,6 +152,16 @@
+
+ connect
+
+
+ org.apache.spark
+ spark-connect_${scala.binary.version}
+ ${project.version}
+
+
+ kubernetes
diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh
index ad31bd1e7b7ab..a137a2fba52ee 100755
--- a/bin/docker-image-tool.sh
+++ b/bin/docker-image-tool.sh
@@ -181,7 +181,7 @@ function build {
error "Failed to build Spark JVM Docker image, please refer to Docker build output for details."
fi
if [ "${CROSS_BUILD}" != "false" ]; then
- (cd $(img_ctx_dir base) && docker buildx build $ARCHS $NOCACHEARG "${BUILD_ARGS[@]}" --push \
+ (cd $(img_ctx_dir base) && docker buildx build $ARCHS $NOCACHEARG "${BUILD_ARGS[@]}" --push --provenance=false \
-t $(image_ref spark) \
-f "$BASEDOCKERFILE" .)
fi
@@ -194,7 +194,7 @@ function build {
error "Failed to build PySpark Docker image, please refer to Docker build output for details."
fi
if [ "${CROSS_BUILD}" != "false" ]; then
- (cd $(img_ctx_dir pyspark) && docker buildx build $ARCHS $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" --push \
+ (cd $(img_ctx_dir pyspark) && docker buildx build $ARCHS $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" --push --provenance=false \
-t $(image_ref spark-py) \
-f "$PYDOCKERFILE" .)
fi
@@ -208,7 +208,7 @@ function build {
error "Failed to build SparkR Docker image, please refer to Docker build output for details."
fi
if [ "${CROSS_BUILD}" != "false" ]; then
- (cd $(img_ctx_dir sparkr) && docker buildx build $ARCHS $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" --push \
+ (cd $(img_ctx_dir sparkr) && docker buildx build $ARCHS $NOCACHEARG "${BINDING_BUILD_ARGS[@]}" --push --provenance=false \
-t $(image_ref spark-r) \
-f "$RDOCKERFILE" .)
fi
@@ -233,7 +233,6 @@ Commands:
Options:
-f file (Optional) Dockerfile to build for JVM based Jobs. By default builds the Dockerfile shipped with Spark.
- For Java 17, use `-f kubernetes/dockerfiles/spark/Dockerfile.java17`
-p file (Optional) Dockerfile to build for PySpark Jobs. Builds Python dependencies and ships with Spark.
Skips building PySpark docker image if not specified.
-R file (Optional) Dockerfile to build for SparkR Jobs. Builds R dependencies and ships with Spark.
@@ -262,25 +261,21 @@ Examples:
$0 -m -t testing build
- Build PySpark docker image
- $0 -r docker.io/myrepo -t v2.3.0 -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile build
+ $0 -r docker.io/myrepo -t v3.4.0 -p kubernetes/dockerfiles/spark/bindings/python/Dockerfile build
- - Build and push image with tag "v2.3.0" to docker.io/myrepo
- $0 -r docker.io/myrepo -t v2.3.0 build
- $0 -r docker.io/myrepo -t v2.3.0 push
+ - Build and push image with tag "v3.4.0" to docker.io/myrepo
+ $0 -r docker.io/myrepo -t v3.4.0 build
+ $0 -r docker.io/myrepo -t v3.4.0 push
- - Build and push Java11-based image with tag "v3.0.0" to docker.io/myrepo
- $0 -r docker.io/myrepo -t v3.0.0 -b java_image_tag=11-jre-slim build
- $0 -r docker.io/myrepo -t v3.0.0 push
+ - Build and push Java11-based image with tag "v3.4.0" to docker.io/myrepo
+ $0 -r docker.io/myrepo -t v3.4.0 -b java_image_tag=11-jre build
+ $0 -r docker.io/myrepo -t v3.4.0 push
- - Build and push Java11-based image for multiple archs to docker.io/myrepo
- $0 -r docker.io/myrepo -t v3.0.0 -X -b java_image_tag=11-jre-slim build
+ - Build and push image for multiple archs to docker.io/myrepo
+ $0 -r docker.io/myrepo -t v3.4.0 -X build
# Note: buildx, which does cross building, needs to do the push during build
# So there is no separate push step with -X
- - Build and push Java17-based image with tag "v3.3.0" to docker.io/myrepo
- $0 -r docker.io/myrepo -t v3.3.0 -f kubernetes/dockerfiles/spark/Dockerfile.java17 build
- $0 -r docker.io/myrepo -t v3.3.0 push
-
EOF
}
diff --git a/bin/pyspark b/bin/pyspark
index 21a514e5e2c4a..1ae28b1f507cd 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -50,7 +50,7 @@ export PYSPARK_DRIVER_PYTHON_OPTS
# Add the PySpark classes to the Python path:
export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
-export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.9.5-src.zip:$PYTHONPATH"
+export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.9.7-src.zip:$PYTHONPATH"
# Load the PySpark shell.py script when ./pyspark is used interactively:
export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index eec02a406b680..232813b4ffdd6 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
)
set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
-set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.9.5-src.zip;%PYTHONPATH%
+set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.9.7-src.zip;%PYTHONPATH%
set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py
diff --git a/bin/spark-class b/bin/spark-class
index c1461a7712289..fc343ca29fddd 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -77,7 +77,8 @@ set +o posix
CMD=()
DELIM=$'\n'
CMD_START_FLAG="false"
-while IFS= read -d "$DELIM" -r ARG; do
+while IFS= read -d "$DELIM" -r _ARG; do
+ ARG=${_ARG//$'\r'}
if [ "$CMD_START_FLAG" == "true" ]; then
CMD+=("$ARG")
else
diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd
index 68b271d1d05d9..800ec0c02c22f 100755
--- a/bin/spark-class2.cmd
+++ b/bin/spark-class2.cmd
@@ -69,6 +69,8 @@ rem SPARK-28302: %RANDOM% would return the same number if we call it instantly a
rem so we should make it sure to generate unique file to avoid process collision of writing into
rem the same file concurrently.
if exist %LAUNCHER_OUTPUT% goto :gen
+rem unset SHELL to indicate non-bash environment to launcher/Main
+set SHELL=
"%RUNNER%" -Xmx128m -cp "%LAUNCH_CLASSPATH%" org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT%
for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do (
set SPARK_CMD=%%i
diff --git a/bin/spark-connect-shell b/bin/spark-connect-shell
new file mode 100755
index 0000000000000..9026c81e70d81
--- /dev/null
+++ b/bin/spark-connect-shell
@@ -0,0 +1,27 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# The shell script to start a spark-shell with Spark Connect enabled.
+
+if [ -z "${SPARK_HOME}" ]; then
+ source "$(dirname "$0")"/find-spark-home
+fi
+
+# This requires building Spark with `-Pconnect`, e.g. `build/sbt -Pconnect package`
+exec "${SPARK_HOME}"/bin/spark-shell --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin "$@"
\ No newline at end of file
diff --git a/bin/sparkR b/bin/sparkR
index 29ab10df8ab6d..8ecc755839fe3 100755
--- a/bin/sparkR
+++ b/bin/sparkR
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
diff --git a/binder/postBuild b/binder/postBuild
index 733eafe175ef0..70ae23b393707 100644
--- a/binder/postBuild
+++ b/binder/postBuild
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
@@ -32,11 +32,24 @@ else
SPECIFIER="<="
fi
-pip install plotly "pyspark[sql,ml,mllib,pandas_on_spark]$SPECIFIER$VERSION"
+if [[ ! $VERSION < "3.4.0" ]]; then
+ pip install plotly "pandas<2.0.0" "pyspark[sql,ml,mllib,pandas_on_spark,connect]$SPECIFIER$VERSION"
+else
+ pip install plotly "pandas<2.0.0" "pyspark[sql,ml,mllib,pandas_on_spark]$SPECIFIER$VERSION"
+fi
# Set 'PYARROW_IGNORE_TIMEZONE' to suppress warnings from PyArrow.
echo "export PYARROW_IGNORE_TIMEZONE=1" >> ~/.profile
+# Add sbin to PATH to run `start-connect-server.sh`.
+SPARK_HOME=$(python -c "from pyspark.find_spark_home import _find_spark_home; print(_find_spark_home())")
+echo "export PATH=${PATH}:${SPARK_HOME}/sbin" >> ~/.profile
+echo "export SPARK_HOME=${SPARK_HOME}" >> ~/.profile
+
+# Add Spark version to env for running command dynamically based on Spark version.
+SPARK_VERSION=$(python -c "import pyspark; print(pyspark.__version__)")
+echo "export SPARK_VERSION=${SPARK_VERSION}" >> ~/.profile
+
# Suppress warnings from Spark jobs and the UI progress bar.
mkdir -p ~/.ipython/profile_default/startup
echo """from pyspark.sql import SparkSession
diff --git a/build/mvn b/build/mvn
index 4989c2d7efd62..aee9358fe44c6 100755
--- a/build/mvn
+++ b/build/mvn
@@ -36,7 +36,7 @@ _DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# Preserve the calling directory
_CALLING_DIR="$(pwd)"
# Options used during compilation
-_COMPILE_JVM_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=1g -Xss128m"
+_COMPILE_JVM_OPTS="-Xss128m -Xmx4g -XX:ReservedCodeCacheSize=128m"
# Installs any application tarball given a URL, the expected tarball name,
# and, optionally, a checkable binary path to determine if the binary has
@@ -119,7 +119,7 @@ install_mvn() {
if [ "$MVN_BIN" ]; then
local MVN_DETECTED_VERSION="$(mvn --version | head -n1 | awk '{print $3}')"
fi
- if [ $(version $MVN_DETECTED_VERSION) -lt $(version $MVN_VERSION) ]; then
+ if [ $(version $MVN_DETECTED_VERSION) -ne $(version $MVN_VERSION) ]; then
local MVN_TARBALL="apache-maven-${MVN_VERSION}-bin.tar.gz"
local FILE_PATH="maven/maven-3/${MVN_VERSION}/binaries/${MVN_TARBALL}"
local APACHE_MIRROR=${APACHE_MIRROR:-'https://www.apache.org/dyn/closer.lua'}
@@ -180,6 +180,13 @@ export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"}
echo "Using \`mvn\` from path: $MVN_BIN" 1>&2
+if [ ! -z "${SPARK_LOCAL_HOSTNAME}" ]; then
+ echo "Using SPARK_LOCAL_HOSTNAME=$SPARK_LOCAL_HOSTNAME" 1>&2
+fi
+if [ ! -z "${SPARK_LOCAL_IP}" ]; then
+ echo "Using SPARK_LOCAL_IP=$SPARK_LOCAL_IP" 1>&2
+fi
+
# call the `mvn` command as usual
# SPARK-25854
"${MVN_BIN}" "$@"
diff --git a/build/sbt b/build/sbt
index 843d2a026ed64..db9d3b345ff6f 100755
--- a/build/sbt
+++ b/build/sbt
@@ -133,6 +133,13 @@ saveSttySettings() {
fi
}
+if [ ! -z "${SPARK_LOCAL_HOSTNAME}" ]; then
+ echo "Using SPARK_LOCAL_HOSTNAME=$SPARK_LOCAL_HOSTNAME" 1>&2
+fi
+if [ ! -z "${SPARK_LOCAL_IP}" ]; then
+ echo "Using SPARK_LOCAL_IP=$SPARK_LOCAL_IP" 1>&2
+fi
+
saveSttySettings
trap onExit INT
diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash
index 8fb6672bddc4d..01ba6b929f922 100755
--- a/build/sbt-launch-lib.bash
+++ b/build/sbt-launch-lib.bash
@@ -183,8 +183,8 @@ run() {
# run sbt
execRunner "$java_cmd" \
- ${SBT_OPTS:-$default_sbt_opts} \
$(get_mem_opts $sbt_mem) \
+ ${SBT_OPTS:-$default_sbt_opts} \
${java_opts} \
${java_args[@]} \
-jar "$sbt_jar" \
diff --git a/build/spark-build-info b/build/spark-build-info
index eb0e3d730e23e..4a4ff9169b3fa 100755
--- a/build/spark-build-info
+++ b/build/spark-build-info
@@ -24,7 +24,7 @@
RESOURCE_DIR="$1"
mkdir -p "$RESOURCE_DIR"
-SPARK_BUILD_INFO="${RESOURCE_DIR}"/spark-version-info.properties
+SPARK_BUILD_INFO="${RESOURCE_DIR%/}"/spark-version-info.properties
echo_build_properties() {
echo version=$1
@@ -33,6 +33,7 @@ echo_build_properties() {
echo branch=$(git rev-parse --abbrev-ref HEAD)
echo date=$(date -u +%Y-%m-%dT%H:%M:%SZ)
echo url=$(git config --get remote.origin.url | sed 's|https://\(.*\)@\(.*\)|https://\2|')
+ echo docroot=https://spark.apache.org/docs/latest
}
echo_build_properties $2 > "$SPARK_BUILD_INFO"
diff --git a/build/spark-build-info.ps1 b/build/spark-build-info.ps1
new file mode 100644
index 0000000000000..43db8823340c6
--- /dev/null
+++ b/build/spark-build-info.ps1
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This script generates the build info for Spark and places it into the spark-version-info.properties file.
+# Arguments:
+# ResourceDir - The target directory where the properties file will be created. [./core/target/extra-resources]
+# SparkVersion - The current version of Spark
+
+param(
+ # The resource directory.
+ [Parameter(Position = 0)]
+ [String]
+ $ResourceDir,
+
+ # The Spark version.
+ [Parameter(Position = 1)]
+ [String]
+ $SparkVersion
+)
+
+$null = New-Item -Type Directory -Force $ResourceDir
+$SparkBuildInfoPath = $ResourceDir.TrimEnd('\').TrimEnd('/') + '\spark-version-info.properties'
+
+$SparkBuildInfoContent =
+"version=$SparkVersion
+user=$($Env:USERNAME)
+revision=$(git rev-parse HEAD)
+branch=$(git rev-parse --abbrev-ref HEAD)
+date=$([DateTime]::UtcNow | Get-Date -UFormat +%Y-%m-%dT%H:%M:%SZ)
+url=$(git config --get remote.origin.url)"
+
+Set-Content -Path $SparkBuildInfoPath -Value $SparkBuildInfoContent
diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml
index 21bf56094503b..bb5467aa0e7a8 100644
--- a/common/kvstore/pom.xml
+++ b/common/kvstore/pom.xml
@@ -22,7 +22,7 @@
org.apache.sparkspark-parent_2.12
- 3.3.1
+ 3.4.1../../pom.xml
@@ -89,7 +89,7 @@
org.apache.logging.log4j
- log4j-slf4j-impl
+ log4j-slf4j2-impltest
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java
index 431c7e42774e4..a353a53d4b8d7 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/InMemoryStore.java
@@ -468,11 +468,6 @@ public T next() {
return iter.next();
}
- @Override
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
@Override
public List next(int max) {
List list = new ArrayList<>(max);
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java
index ff99d052cf7a2..02dd73e1a2f27 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVStoreSerializer.java
@@ -49,7 +49,7 @@ public KVStoreSerializer() {
this.mapper = new ObjectMapper();
}
- public final byte[] serialize(Object o) throws Exception {
+ public byte[] serialize(Object o) throws Exception {
if (o instanceof String) {
return ((String) o).getBytes(UTF_8);
} else {
@@ -62,7 +62,7 @@ public final byte[] serialize(Object o) throws Exception {
}
@SuppressWarnings("unchecked")
- public final T deserialize(byte[] data, Class klass) throws Exception {
+ public T deserialize(byte[] data, Class klass) throws Exception {
if (klass.equals(String.class)) {
return (T) new String(data, UTF_8);
} else {
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java
index a7e5831846ad4..a15d07cf59958 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java
@@ -48,7 +48,6 @@ public KVTypeInfo(Class> type) {
checkIndex(idx, indices);
f.setAccessible(true);
indices.put(idx.value(), idx);
- f.setAccessible(true);
accessors.put(idx.value(), new FieldAccessor(f));
}
}
@@ -61,7 +60,6 @@ public KVTypeInfo(Class> type) {
"Annotated method %s::%s should not have any parameters.", type.getName(), m.getName());
m.setAccessible(true);
indices.put(idx.value(), idx);
- m.setAccessible(true);
accessors.put(idx.value(), new MethodAccessor(m));
}
}
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java
index 6b28373a48065..b50906e2cbac4 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java
@@ -270,10 +270,14 @@ public boolean removeAllByIndexValues(
KVStoreView view = view(klass).index(index);
for (Object indexValue : indexValues) {
- for (T value: view.first(indexValue).last(indexValue)) {
- Object itemKey = naturalIndex.getValue(value);
- delete(klass, itemKey);
- removed = true;
+ try (KVStoreIterator iterator =
+ view.first(indexValue).last(indexValue).closeableIterator()) {
+ while (iterator.hasNext()) {
+ T value = iterator.next();
+ Object itemKey = naturalIndex.getValue(value);
+ delete(klass, itemKey);
+ removed = true;
+ }
}
}
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java
index e8fb4fac5ba17..35d0c6065fb0f 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java
@@ -143,11 +143,6 @@ public T next() {
}
}
- @Override
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
@Override
public List next(int max) {
List list = new ArrayList<>(max);
@@ -159,6 +154,8 @@ public List next(int max) {
@Override
public boolean skip(long n) {
+ if (closed) return false;
+
long skipped = 0;
while (skipped < n) {
if (next != null) {
@@ -189,6 +186,7 @@ public synchronized void close() throws IOException {
if (!closed) {
it.close();
closed = true;
+ next = null;
}
}
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java
index 7674bc52dc750..d328e5c79d341 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java
@@ -303,10 +303,14 @@ public boolean removeAllByIndexValues(
KVStoreView view = view(klass).index(index);
for (Object indexValue : indexValues) {
- for (T value: view.first(indexValue).last(indexValue)) {
- Object itemKey = naturalIndex.getValue(value);
- delete(klass, itemKey);
- removed = true;
+ try (KVStoreIterator iterator =
+ view.first(indexValue).last(indexValue).closeableIterator()) {
+ while (iterator.hasNext()) {
+ T value = iterator.next();
+ Object itemKey = naturalIndex.getValue(value);
+ delete(klass, itemKey);
+ removed = true;
+ }
}
}
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java
index 1db47f4dad00a..2b12fddef6583 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java
@@ -134,11 +134,6 @@ public T next() {
}
}
- @Override
- public void remove() {
- throw new UnsupportedOperationException();
- }
-
@Override
public List next(int max) {
List list = new ArrayList<>(max);
@@ -150,6 +145,8 @@ public List next(int max) {
@Override
public boolean skip(long n) {
+ if (closed) return false;
+
long skipped = 0;
while (skipped < n) {
if (next != null) {
@@ -183,6 +180,7 @@ public synchronized void close() throws IOException {
if (!closed) {
it.close();
closed = true;
+ next = null;
}
}
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java
index ab1e27285853e..223f3f93a8790 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorSuite.java
@@ -490,11 +490,15 @@ private void compareLists(Iterable> expected, List> actual) {
}
private KVStoreView view() throws Exception {
+ // SPARK-38896: this `view` will be closed in
+ // the `collect(KVStoreView view)` method.
return db.view(CustomType1.class);
}
private List collect(KVStoreView view) throws Exception {
- return Arrays.asList(Iterables.toArray(view, CustomType1.class));
+ try (KVStoreIterator iterator = view.closeableIterator()) {
+ return Lists.newArrayList(iterator);
+ }
}
private List sortBy(Comparator comp) {
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java
index 35656fb12238a..b2acd1ae15b16 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java
@@ -34,24 +34,14 @@ public void testObjectWriteReadDelete() throws Exception {
t.id = "id";
t.name = "name";
- try {
- store.read(CustomType1.class, t.key);
- fail("Expected exception for non-existent object.");
- } catch (NoSuchElementException nsee) {
- // Expected.
- }
+ assertThrows(NoSuchElementException.class, () -> store.read(CustomType1.class, t.key));
store.write(t);
assertEquals(t, store.read(t.getClass(), t.key));
assertEquals(1L, store.count(t.getClass()));
store.delete(t.getClass(), t.key);
- try {
- store.read(t.getClass(), t.key);
- fail("Expected exception for deleted object.");
- } catch (NoSuchElementException nsee) {
- // Expected.
- }
+ assertThrows(NoSuchElementException.class, () -> store.read(t.getClass(), t.key));
}
@Test
@@ -78,12 +68,7 @@ public void testMultipleObjectWriteReadDelete() throws Exception {
store.delete(t1.getClass(), t1.key);
assertEquals(t2, store.read(t2.getClass(), t2.key));
store.delete(t2.getClass(), t2.key);
- try {
- store.read(t2.getClass(), t2.key);
- fail("Expected exception for deleted object.");
- } catch (NoSuchElementException nsee) {
- // Expected.
- }
+ assertThrows(NoSuchElementException.class, () -> store.read(t2.getClass(), t2.key));
}
@Test
@@ -159,25 +144,25 @@ public void testRemoveAll() throws Exception {
assertEquals(9, store.count(ArrayKeyIndexType.class));
// Try removing non-existing keys
- assert(!store.removeAllByIndexValues(
+ assertFalse(store.removeAllByIndexValues(
ArrayKeyIndexType.class,
KVIndex.NATURAL_INDEX_NAME,
ImmutableSet.of(new int[] {10, 10, 10}, new int[] { 3, 3, 3 })));
assertEquals(9, store.count(ArrayKeyIndexType.class));
- assert(store.removeAllByIndexValues(
+ assertTrue(store.removeAllByIndexValues(
ArrayKeyIndexType.class,
KVIndex.NATURAL_INDEX_NAME,
ImmutableSet.of(new int[] {0, 0, 0}, new int[] { 2, 2, 2 })));
assertEquals(7, store.count(ArrayKeyIndexType.class));
- assert(store.removeAllByIndexValues(
+ assertTrue(store.removeAllByIndexValues(
ArrayKeyIndexType.class,
"id",
ImmutableSet.of(new String [] { "things" })));
assertEquals(4, store.count(ArrayKeyIndexType.class));
- assert(store.removeAllByIndexValues(
+ assertTrue(store.removeAllByIndexValues(
ArrayKeyIndexType.class,
"id",
ImmutableSet.of(new String [] { "more things" })));
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBBenchmark.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBBenchmark.java
index f2a91f916a309..9082e1887bf85 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBBenchmark.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBBenchmark.java
@@ -197,9 +197,15 @@ private void iterate(KVStoreView> view, String name) throws Exception {
}
}
- while (it.hasNext()) {
- try(Timer.Context ctx = iter.time()) {
- it.next();
+ try {
+ while (it.hasNext()) {
+ try (Timer.Context ctx = iter.time()) {
+ it.next();
+ }
+ }
+ } finally {
+ if (it != null) {
+ it.close();
}
}
}
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java
index c43c9b171f5a4..86f65e9be895f 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java
@@ -22,6 +22,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
+import java.util.Spliterators;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
@@ -71,36 +72,21 @@ public void testReopenAndVersionCheckDb() throws Exception {
db.close();
db = null;
- try {
- db = new LevelDB(dbpath);
- fail("Should have failed version check.");
- } catch (UnsupportedStoreVersionException e) {
- // Expected.
- }
+ assertThrows(UnsupportedStoreVersionException.class, () -> db = new LevelDB(dbpath));
}
@Test
public void testObjectWriteReadDelete() throws Exception {
CustomType1 t = createCustomType1(1);
- try {
- db.read(CustomType1.class, t.key);
- fail("Expected exception for non-existent object.");
- } catch (NoSuchElementException nsee) {
- // Expected.
- }
+ assertThrows(NoSuchElementException.class, () -> db.read(CustomType1.class, t.key));
db.write(t);
assertEquals(t, db.read(t.getClass(), t.key));
assertEquals(1L, db.count(t.getClass()));
db.delete(t.getClass(), t.key);
- try {
- db.read(t.getClass(), t.key);
- fail("Expected exception for deleted object.");
- } catch (NoSuchElementException nsee) {
- // Expected.
- }
+ assertThrows(NoSuchElementException.class, () -> db.read(t.getClass(), t.key));
// Look into the actual DB and make sure that all the keys related to the type have been
// removed.
@@ -251,13 +237,14 @@ public void testSkip() throws Exception {
db.write(createCustomType1(i));
}
- KVStoreIterator it = db.view(CustomType1.class).closeableIterator();
- assertTrue(it.hasNext());
- assertTrue(it.skip(5));
- assertEquals("key5", it.next().key);
- assertTrue(it.skip(3));
- assertEquals("key9", it.next().key);
- assertFalse(it.hasNext());
+ try (KVStoreIterator it = db.view(CustomType1.class).closeableIterator()) {
+ assertTrue(it.hasNext());
+ assertTrue(it.skip(5));
+ assertEquals("key5", it.next().key);
+ assertTrue(it.skip(3));
+ assertEquals("key9", it.next().key);
+ assertFalse(it.hasNext());
+ }
}
@Test
@@ -272,12 +259,15 @@ public void testNegativeIndexValues() throws Exception {
}
});
- List<Integer> results = StreamSupport
- .stream(db.view(CustomType1.class).index("int").spliterator(), false)
- .map(e -> e.num)
- .collect(Collectors.toList());
+ try (KVStoreIterator<CustomType1> iterator =
+ db.view(CustomType1.class).index("int").closeableIterator()) {
+ List<Integer> results = StreamSupport
+ .stream(Spliterators.spliteratorUnknownSize(iterator, 0), false)
+ .map(e -> e.num)
+ .collect(Collectors.toList());
- assertEquals(expected, results);
+ assertEquals(expected, results);
+ }
}
@Test
@@ -315,6 +305,84 @@ public void testCloseLevelDBIterator() throws Exception {
assertTrue(!dbPathForCloseTest.exists());
}
+ @Test
+ public void testHasNextAfterIteratorClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // iter should be true
+ assertTrue(iter.hasNext());
+ // close iter
+ iter.close();
+ // iter.hasNext should be false after iter close
+ assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testHasNextAfterDBClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // iter should be true
+ assertTrue(iter.hasNext());
+ // close db
+ db.close();
+ // iter.hasNext should be false after db close
+ assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testNextAfterIteratorClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // iter should be true
+ assertTrue(iter.hasNext());
+ // close iter
+ iter.close();
+ // iter.next should throw NoSuchElementException after iter close
+ assertThrows(NoSuchElementException.class, iter::next);
+ }
+
+ @Test
+ public void testNextAfterDBClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // iter should be true
+ assertTrue(iter.hasNext());
+ // close db
+ iter.close();
+ // iter.next should throw NoSuchElementException after db close
+ assertThrows(NoSuchElementException.class, iter::next);
+ }
+
+ @Test
+ public void testSkipAfterIteratorClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // close iter
+ iter.close();
+ // skip should always return false after iter close
+ assertFalse(iter.skip(0));
+ assertFalse(iter.skip(1));
+ }
+
+ @Test
+ public void testSkipAfterDBClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // iter should be true
+ assertTrue(iter.hasNext());
+ // close db
+ db.close();
+ // skip should always return false after db close
+ assertFalse(iter.skip(0));
+ assertFalse(iter.skip(1));
+ }
+
private CustomType1 createCustomType1(int i) {
CustomType1 t = new CustomType1();
t.key = "key" + i;
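
The LevelDBSuite changes above swap the old `try { ...; fail(...) } catch (Expected e) { }` idiom for JUnit 4.13's `assertThrows`, and the new iterator tests reuse the same call. A minimal, self-contained sketch of the pattern, using plain JDK objects in place of the LevelDB store (class and test names here are made up for illustration):

```java
import static org.junit.Assert.assertThrows;

import java.util.Collections;
import java.util.Iterator;
import java.util.NoSuchElementException;

import org.junit.Test;

public class AssertThrowsPatternSuite {

  @Test
  public void nextOnEmptyIteratorThrows() {
    Iterator<String> empty = Collections.<String>emptyList().iterator();
    // One call names both the expected exception type and the code under test.
    // assertThrows also returns the caught exception in case it needs further inspection.
    assertThrows(NoSuchElementException.class, empty::next);
  }
}
```
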
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBTypeInfoSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBTypeInfoSuite.java
index 38db3bedaef6a..0359e11404cd4 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBTypeInfoSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBTypeInfoSuite.java
@@ -43,34 +43,40 @@ public void testIndexAnnotation() throws Exception {
assertEquals(t1.child, ti.getIndexValue("child", t1));
}
- @Test(expected = IllegalArgumentException.class)
- public void testNoNaturalIndex() throws Exception {
- newTypeInfo(NoNaturalIndex.class);
+ @Test
+ public void testNoNaturalIndex() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(NoNaturalIndex.class));
}
- @Test(expected = IllegalArgumentException.class)
- public void testNoNaturalIndex2() throws Exception {
- newTypeInfo(NoNaturalIndex2.class);
+ @Test
+ public void testNoNaturalIndex2() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(NoNaturalIndex2.class));
}
- @Test(expected = IllegalArgumentException.class)
- public void testDuplicateIndex() throws Exception {
- newTypeInfo(DuplicateIndex.class);
+ @Test
+ public void testDuplicateIndex() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(DuplicateIndex.class));
}
- @Test(expected = IllegalArgumentException.class)
- public void testEmptyIndexName() throws Exception {
- newTypeInfo(EmptyIndexName.class);
+ @Test
+ public void testEmptyIndexName() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(EmptyIndexName.class));
}
- @Test(expected = IllegalArgumentException.class)
- public void testIllegalIndexName() throws Exception {
- newTypeInfo(IllegalIndexName.class);
+ @Test
+ public void testIllegalIndexName() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(IllegalIndexName.class));
}
- @Test(expected = IllegalArgumentException.class)
- public void testIllegalIndexMethod() throws Exception {
- newTypeInfo(IllegalIndexMethod.class);
+ @Test
+ public void testIllegalIndexMethod() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(IllegalIndexMethod.class));
}
@Test
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBBenchmark.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBBenchmark.java
index 4517a47b32f6b..25930bb1013d9 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBBenchmark.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBBenchmark.java
@@ -196,10 +196,15 @@ private void iterate(KVStoreView<?> view, String name) throws Exception {
}
}
}
-
- while (it.hasNext()) {
- try(Timer.Context ctx = iter.time()) {
- it.next();
+ try {
+ while (it.hasNext()) {
+ try (Timer.Context ctx = iter.time()) {
+ it.next();
+ }
+ }
+ } finally {
+ if (it != null) {
+ it.close();
}
}
}
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBIteratorSuite.java
index d4bfc7e0413ab..5450f6531d60c 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBIteratorSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBIteratorSuite.java
@@ -20,11 +20,8 @@
import java.io.File;
import org.apache.commons.io.FileUtils;
-import org.apache.commons.lang3.SystemUtils;
import org.junit.AfterClass;
-import static org.junit.Assume.assumeFalse;
-
public class RocksDBIteratorSuite extends DBIteratorSuite {
private static File dbpath;
@@ -42,7 +39,6 @@ public static void cleanup() throws Exception {
@Override
protected KVStore createStore() throws Exception {
- assumeFalse(SystemUtils.IS_OS_MAC_OSX && SystemUtils.OS_ARCH.equals("aarch64"));
dbpath = File.createTempFile("test.", ".rdb");
dbpath.delete();
db = new RocksDB(dbpath);
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java
index cd18d227cba72..602ab2d6881a3 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java
@@ -22,19 +22,18 @@
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
+import java.util.Spliterators;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import com.google.common.collect.ImmutableSet;
import org.apache.commons.io.FileUtils;
-import org.apache.commons.lang3.SystemUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.rocksdb.RocksIterator;
import static org.junit.Assert.*;
-import static org.junit.Assume.assumeFalse;
public class RocksDBSuite {
@@ -53,7 +52,6 @@ public void cleanup() throws Exception {
@Before
public void setup() throws Exception {
- assumeFalse(SystemUtils.IS_OS_MAC_OSX && SystemUtils.OS_ARCH.equals("aarch64"));
dbpath = File.createTempFile("test.", ".rdb");
dbpath.delete();
db = new RocksDB(dbpath);
@@ -72,36 +70,21 @@ public void testReopenAndVersionCheckDb() throws Exception {
db.close();
db = null;
- try {
- db = new RocksDB(dbpath);
- fail("Should have failed version check.");
- } catch (UnsupportedStoreVersionException e) {
- // Expected.
- }
+ assertThrows(UnsupportedStoreVersionException.class, () -> db = new RocksDB(dbpath));
}
@Test
public void testObjectWriteReadDelete() throws Exception {
CustomType1 t = createCustomType1(1);
- try {
- db.read(CustomType1.class, t.key);
- fail("Expected exception for non-existent object.");
- } catch (NoSuchElementException nsee) {
- // Expected.
- }
+ assertThrows(NoSuchElementException.class, () -> db.read(CustomType1.class, t.key));
db.write(t);
assertEquals(t, db.read(t.getClass(), t.key));
assertEquals(1L, db.count(t.getClass()));
db.delete(t.getClass(), t.key);
- try {
- db.read(t.getClass(), t.key);
- fail("Expected exception for deleted object.");
- } catch (NoSuchElementException nsee) {
- // Expected.
- }
+ assertThrows(NoSuchElementException.class, () -> db.read(t.getClass(), t.key));
// Look into the actual DB and make sure that all the keys related to the type have been
// removed.
@@ -252,13 +235,14 @@ public void testSkip() throws Exception {
db.write(createCustomType1(i));
}
- KVStoreIterator<CustomType1> it = db.view(CustomType1.class).closeableIterator();
- assertTrue(it.hasNext());
- assertTrue(it.skip(5));
- assertEquals("key5", it.next().key);
- assertTrue(it.skip(3));
- assertEquals("key9", it.next().key);
- assertFalse(it.hasNext());
+ try (KVStoreIterator<CustomType1> it = db.view(CustomType1.class).closeableIterator()) {
+ assertTrue(it.hasNext());
+ assertTrue(it.skip(5));
+ assertEquals("key5", it.next().key);
+ assertTrue(it.skip(3));
+ assertEquals("key9", it.next().key);
+ assertFalse(it.hasNext());
+ }
}
@Test
@@ -273,12 +257,15 @@ public void testNegativeIndexValues() throws Exception {
}
});
- List<Integer> results = StreamSupport
- .stream(db.view(CustomType1.class).index("int").spliterator(), false)
- .map(e -> e.num)
- .collect(Collectors.toList());
+ try (KVStoreIterator<CustomType1> iterator =
+ db.view(CustomType1.class).index("int").closeableIterator()) {
+ List<Integer> results = StreamSupport
+ .stream(Spliterators.spliteratorUnknownSize(iterator, 0), false)
+ .map(e -> e.num)
+ .collect(Collectors.toList());
- assertEquals(expected, results);
+ assertEquals(expected, results);
+ }
}
@Test
@@ -316,6 +303,84 @@ public void testCloseRocksDBIterator() throws Exception {
assertTrue(!dbPathForCloseTest.exists());
}
+ @Test
+ public void testHasNextAfterIteratorClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // iter should be true
+ assertTrue(iter.hasNext());
+ // close iter
+ iter.close();
+ // iter.hasNext should be false after iter close
+ assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testHasNextAfterDBClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // iter should be true
+ assertTrue(iter.hasNext());
+ // close db
+ db.close();
+ // iter.hasNext should be false after db close
+ assertFalse(iter.hasNext());
+ }
+
+ @Test
+ public void testNextAfterIteratorClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // iter should be true
+ assertTrue(iter.hasNext());
+ // close iter
+ iter.close();
+ // iter.next should throw NoSuchElementException after iter close
+ assertThrows(NoSuchElementException.class, iter::next);
+ }
+
+ @Test
+ public void testNextAfterDBClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // iter should be true
+ assertTrue(iter.hasNext());
+ // close db
+ iter.close();
+ // iter.next should throw NoSuchElementException after db close
+ assertThrows(NoSuchElementException.class, iter::next);
+ }
+
+ @Test
+ public void testSkipAfterIteratorClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // close iter
+ iter.close();
+ // skip should always return false after iter close
+ assertFalse(iter.skip(0));
+ assertFalse(iter.skip(1));
+ }
+
+ @Test
+ public void testSkipAfterDBClose() throws Exception {
+ db.write(createCustomType1(0));
+ KVStoreIterator<CustomType1> iter =
+ db.view(CustomType1.class).closeableIterator();
+ // iter should be true
+ assertTrue(iter.hasNext());
+ // close db
+ db.close();
+ // skip should always return false after db close
+ assertFalse(iter.skip(0));
+ assertFalse(iter.skip(1));
+ }
+
private CustomType1 createCustomType1(int i) {
CustomType1 t = new CustomType1();
t.key = "key" + i;
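
Both key-value store suites now close their iterators deterministically: `closeableIterator()` results are wrapped in try-with-resources, and streams are built with `Spliterators.spliteratorUnknownSize` instead of `view.spliterator()`. A stand-alone sketch of that shape using only the JDK; `CloseableIter` below is a hypothetical stand-in for `KVStoreIterator`:

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Spliterators;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

// Hypothetical stand-in for a KVStoreIterator: an Iterator that must also be closed.
class CloseableIter<T> implements Iterator<T>, AutoCloseable {
  private final Iterator<T> delegate;

  CloseableIter(Iterator<T> delegate) {
    this.delegate = delegate;
  }

  @Override public boolean hasNext() { return delegate.hasNext(); }
  @Override public T next() { return delegate.next(); }
  @Override public void close() { /* release the underlying resource here */ }
}

public class CloseableIteratorStreaming {
  public static void main(String[] args) {
    List<Integer> source = Arrays.asList(3, 1, 2);
    // try-with-resources guarantees close() even if the stream pipeline throws,
    // which is the point of moving from spliterator() to closeableIterator().
    try (CloseableIter<Integer> it = new CloseableIter<>(source.iterator())) {
      List<Integer> doubled = StreamSupport
          .stream(Spliterators.spliteratorUnknownSize(it, 0), false)
          .map(n -> n * 2)
          .collect(Collectors.toList());
      System.out.println(doubled); // [6, 2, 4]
    }
  }
}
```
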
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBTypeInfoSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBTypeInfoSuite.java
index a51fd1a7fea58..f694fd36b68b3 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBTypeInfoSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBTypeInfoSuite.java
@@ -43,34 +43,40 @@ public void testIndexAnnotation() throws Exception {
assertEquals(t1.child, ti.getIndexValue("child", t1));
}
- @Test(expected = IllegalArgumentException.class)
- public void testNoNaturalIndex() throws Exception {
- newTypeInfo(NoNaturalIndex.class);
+ @Test
+ public void testNoNaturalIndex() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(NoNaturalIndex.class));
}
- @Test(expected = IllegalArgumentException.class)
- public void testNoNaturalIndex2() throws Exception {
- newTypeInfo(NoNaturalIndex2.class);
+ @Test
+ public void testNoNaturalIndex2() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(NoNaturalIndex2.class));
}
- @Test(expected = IllegalArgumentException.class)
- public void testDuplicateIndex() throws Exception {
- newTypeInfo(DuplicateIndex.class);
+ @Test
+ public void testDuplicateIndex() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(DuplicateIndex.class));
}
- @Test(expected = IllegalArgumentException.class)
- public void testEmptyIndexName() throws Exception {
- newTypeInfo(EmptyIndexName.class);
+ @Test
+ public void testEmptyIndexName() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(EmptyIndexName.class));
}
- @Test(expected = IllegalArgumentException.class)
- public void testIllegalIndexName() throws Exception {
- newTypeInfo(IllegalIndexName.class);
+ @Test
+ public void testIllegalIndexName() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(IllegalIndexName.class));
}
- @Test(expected = IllegalArgumentException.class)
- public void testIllegalIndexMethod() throws Exception {
- newTypeInfo(IllegalIndexMethod.class);
+ @Test
+ public void testIllegalIndexMethod() {
+ assertThrows(IllegalArgumentException.class,
+ () -> newTypeInfo(IllegalIndexMethod.class));
}
@Test
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index 43740354d84d1..aa8efeb8143e0 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -22,7 +22,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.12</artifactId>
-    <version>3.3.1</version>
+    <version>3.4.1</version>
     <relativePath>../../pom.xml</relativePath>
   </parent>
@@ -42,20 +42,46 @@
     <dependency>
       <groupId>io.netty</groupId>
       <artifactId>netty-all</artifactId>
     </dependency>
+    <dependency>
+      <groupId>io.netty</groupId>
+      <artifactId>netty-transport-native-epoll</artifactId>
+      <classifier>linux-x86_64</classifier>
+    </dependency>
+    <dependency>
+      <groupId>io.netty</groupId>
+      <artifactId>netty-transport-native-epoll</artifactId>
+      <classifier>linux-aarch_64</classifier>
+    </dependency>
+    <dependency>
+      <groupId>io.netty</groupId>
+      <artifactId>netty-transport-native-kqueue</artifactId>
+      <classifier>osx-aarch_64</classifier>
+    </dependency>
+    <dependency>
+      <groupId>io.netty</groupId>
+      <artifactId>netty-transport-native-kqueue</artifactId>
+      <classifier>osx-x86_64</classifier>
+    </dependency>
+
     <dependency>
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-lang3</artifactId>
     </dependency>
-
     <dependency>
       <groupId>${leveldbjni.group}</groupId>
       <artifactId>leveldbjni-all</artifactId>
       <version>1.8</version>
     </dependency>
+    <dependency>
+      <groupId>org.rocksdb</groupId>
+      <artifactId>rocksdbjni</artifactId>
+    </dependency>
     <dependency>
       <groupId>com.fasterxml.jackson.core</groupId>
@@ -118,14 +144,13 @@
     <dependency>
       <groupId>org.apache.logging.log4j</groupId>
-      <artifactId>log4j-slf4j-impl</artifactId>
+      <artifactId>log4j-slf4j2-impl</artifactId>
       <scope>test</scope>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-tags_${scala.binary.version}</artifactId>
-      <scope>test</scope>
     </dependency>
diff --git a/connector/avro/pom.xml b/connector/avro/pom.xml
new file mode 100644
--- /dev/null
+++ b/connector/avro/pom.xml
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.12</artifactId>
+    <version>3.4.1</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-avro_2.12</artifactId>
+  <properties>
+    <sbt.project.name>avro</sbt.project.name>
+  </properties>
+  <packaging>jar</packaging>
+  <name>Spark Avro</name>
+  <url>https://spark.apache.org/</url>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalacheck</groupId>
+      <artifactId>scalacheck_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.tukaani</groupId>
+      <artifactId>xz</artifactId>
+    </dependency>
+  </dependencies>
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+  </build>
+</project>
diff --git a/external/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java b/connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java
similarity index 93%
rename from external/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java
rename to connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java
index a4555844b5117..b2a57060fc2d9 100644
--- a/external/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java
+++ b/connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java
@@ -25,6 +25,7 @@
import org.apache.avro.file.CodecFactory;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.mapred.AvroKey;
import org.apache.avro.mapreduce.AvroKeyOutputFormat;
@@ -46,13 +47,14 @@ static class SparkRecordWriterFactory extends RecordWriterFactory<GenericRecord> {
this.metadata = metadata;
}
+ @Override
protected RecordWriter<AvroKey<GenericRecord>, NullWritable> create(
Schema writerSchema,
GenericData dataModel,
CodecFactory compressionCodec,
OutputStream outputStream,
int syncInterval) throws IOException {
- return new SparkAvroKeyRecordWriter(
+ return new SparkAvroKeyRecordWriter<>(
writerSchema, dataModel, compressionCodec, outputStream, syncInterval, metadata);
}
}
@@ -71,7 +73,7 @@ class SparkAvroKeyRecordWriter<T> extends RecordWriter<AvroKey<T>, NullWritable>
OutputStream outputStream,
int syncInterval,
Map<String, String> metadata) throws IOException {
- this.mAvroFileWriter = new DataFileWriter(dataModel.createDatumWriter(writerSchema));
+ this.mAvroFileWriter = new DataFileWriter<>(new GenericDatumWriter<>(writerSchema, dataModel));
for (Map.Entry<String, String> entry : metadata.entrySet()) {
this.mAvroFileWriter.setMeta(entry.getKey(), entry.getValue());
}
@@ -80,14 +82,17 @@ class SparkAvroKeyRecordWriter extends RecordWriter, NullWritable>
this.mAvroFileWriter.create(writerSchema, outputStream);
}
+ @Override
public void write(AvroKey<T> record, NullWritable ignore) throws IOException {
this.mAvroFileWriter.append(record.datum());
}
+ @Override
public void close(TaskAttemptContext context) throws IOException {
this.mAvroFileWriter.close();
}
+ @Override
public long sync() throws IOException {
return this.mAvroFileWriter.sync();
}
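
The writer above now builds its `DataFileWriter` from an explicit `GenericDatumWriter` bound to both the writer schema and the data model, instead of the raw `dataModel.createDatumWriter(writerSchema)` result. A small sketch of that construction against the plain Avro API (the schema, record, and output here are made up for illustration):

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.GenericRecordBuilder;

public class DatumWriterSketch {
  public static void main(String[] args) throws IOException {
    Schema schema = SchemaBuilder.record("example").fields()
        .requiredString("name")
        .endRecord();
    // The datum writer is typed and explicitly tied to a known data model.
    GenericDatumWriter<GenericRecord> datumWriter =
        new GenericDatumWriter<>(schema, GenericData.get());
    try (DataFileWriter<GenericRecord> fileWriter = new DataFileWriter<>(datumWriter)) {
      ByteArrayOutputStream out = new ByteArrayOutputStream();
      fileWriter.create(schema, out);
      fileWriter.append(new GenericRecordBuilder(schema).set("name", "a").build());
    }
  }
}
```
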
diff --git a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/connector/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
similarity index 100%
rename from external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
rename to connector/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
similarity index 100%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
similarity index 99%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 1192856ae7796..aac979cddb2dd 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -29,7 +29,7 @@ import org.apache.avro.Schema.Type._
import org.apache.avro.generic._
import org.apache.avro.util.Utf8
-import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField}
+import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField}
import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
@@ -289,8 +289,7 @@ private[sql] class AvroDeserializer(
updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
case (UNION, _) =>
- val allTypes = avroType.getTypes.asScala
- val nonNullTypes = allTypes.filter(_.getType != NULL)
+ val nonNullTypes = nonNullUnionBranches(avroType)
val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava)
if (nonNullTypes.nonEmpty) {
if (nonNullTypes.length == 1) {
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
similarity index 97%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index a13e0624f351d..3e16e12108129 100755
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.avro
import java.io._
-import java.net.URI
import scala.util.control.NonFatal
@@ -96,9 +95,9 @@ private[sql] class AvroFileFormat extends FileFormat
// Doing input file filtering is improper because we may generate empty tasks that process no
// input files but stress the scheduler. We should probably add a more general input file
// filtering mechanism for `FileFormat` data sources. See SPARK-16317.
- if (parsedOptions.ignoreExtension || file.filePath.endsWith(".avro")) {
+ if (parsedOptions.ignoreExtension || file.urlEncodedPath.endsWith(".avro")) {
val reader = {
- val in = new FsInput(new Path(new URI(file.filePath)), conf)
+ val in = new FsInput(file.toPath, conf)
try {
val datumReader = userProvidedSchema match {
case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema)
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
similarity index 78%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
index fec2b77773ddc..95001bb81508c 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode}
import org.apache.spark.sql.internal.SQLConf
@@ -33,7 +34,10 @@ import org.apache.spark.sql.internal.SQLConf
*/
private[sql] class AvroOptions(
@transient val parameters: CaseInsensitiveMap[String],
- @transient val conf: Configuration) extends Logging with Serializable {
+ @transient val conf: Configuration)
+ extends FileSourceOptions(parameters) with Logging {
+
+ import AvroOptions._
def this(parameters: Map[String, String], conf: Configuration) = {
this(CaseInsensitiveMap(parameters), conf)
@@ -52,8 +56,8 @@ private[sql] class AvroOptions(
* instead of "string" type in the default converted schema.
*/
val schema: Option[Schema] = {
- parameters.get("avroSchema").map(new Schema.Parser().setValidateDefaults(false).parse).orElse({
- val avroUrlSchema = parameters.get("avroSchemaUrl").map(url => {
+ parameters.get(AVRO_SCHEMA).map(new Schema.Parser().setValidateDefaults(false).parse).orElse({
+ val avroUrlSchema = parameters.get(AVRO_SCHEMA_URL).map(url => {
log.debug("loading avro schema from url: " + url)
val fs = FileSystem.get(new URI(url), conf)
val in = fs.open(new Path(url))
@@ -73,20 +77,20 @@ private[sql] class AvroOptions(
* whose field names do not match. Defaults to false.
*/
val positionalFieldMatching: Boolean =
- parameters.get("positionalFieldMatching").exists(_.toBoolean)
+ parameters.get(POSITIONAL_FIELD_MATCHING).exists(_.toBoolean)
/**
* Top level record name in write result, which is required in Avro spec.
- * See https://avro.apache.org/docs/1.11.0/spec.html#schema_record .
+ * See https://avro.apache.org/docs/1.11.1/specification/#schema-record .
* Default value is "topLevelRecord"
*/
- val recordName: String = parameters.getOrElse("recordName", "topLevelRecord")
+ val recordName: String = parameters.getOrElse(RECORD_NAME, "topLevelRecord")
/**
* Record namespace in write result. Default value is "".
- * See Avro spec for details: https://avro.apache.org/docs/1.11.0/spec.html#schema_record .
+ * See Avro spec for details: https://avro.apache.org/docs/1.11.1/specification/#schema-record .
*/
- val recordNamespace: String = parameters.getOrElse("recordNamespace", "")
+ val recordNamespace: String = parameters.getOrElse(RECORD_NAMESPACE, "")
/**
* The `ignoreExtension` option controls ignoring of files without `.avro` extensions in read.
@@ -102,7 +106,7 @@ private[sql] class AvroOptions(
ignoreFilesWithoutExtensionByDefault)
parameters
- .get(AvroOptions.ignoreExtensionKey)
+ .get(IGNORE_EXTENSION)
.map(_.toBoolean)
.getOrElse(!ignoreFilesWithoutExtension)
}
@@ -114,21 +118,21 @@ private[sql] class AvroOptions(
* taken into account. If the former one is not set too, the `snappy` codec is used by default.
*/
val compression: String = {
- parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec)
+ parameters.get(COMPRESSION).getOrElse(SQLConf.get.avroCompressionCodec)
}
val parseMode: ParseMode =
- parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)
+ parameters.get(MODE).map(ParseMode.fromString).getOrElse(FailFastMode)
/**
* The rebasing mode for the DATE and TIMESTAMP_MICROS, TIMESTAMP_MILLIS values in reads.
*/
val datetimeRebaseModeInRead: String = parameters
- .get(AvroOptions.DATETIME_REBASE_MODE)
+ .get(DATETIME_REBASE_MODE)
.getOrElse(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_READ))
}
-private[sql] object AvroOptions {
+private[sql] object AvroOptions extends DataSourceOptions {
def apply(parameters: Map[String, String]): AvroOptions = {
val hadoopConf = SparkSession
.getActiveSession
@@ -137,11 +141,17 @@ private[sql] object AvroOptions {
new AvroOptions(CaseInsensitiveMap(parameters), hadoopConf)
}
- val ignoreExtensionKey = "ignoreExtension"
-
+ val IGNORE_EXTENSION = newOption("ignoreExtension")
+ val MODE = newOption("mode")
+ val RECORD_NAME = newOption("recordName")
+ val COMPRESSION = newOption("compression")
+ val AVRO_SCHEMA = newOption("avroSchema")
+ val AVRO_SCHEMA_URL = newOption("avroSchemaUrl")
+ val RECORD_NAMESPACE = newOption("recordNamespace")
+ val POSITIONAL_FIELD_MATCHING = newOption("positionalFieldMatching")
// The option controls rebasing of the DATE and TIMESTAMP values between
// Julian and Proleptic Gregorian calendars. It impacts on the behaviour of the Avro
// datasource similarly to the SQL config `spark.sql.avro.datetimeRebaseModeInRead`,
// and can be set to the same values: `EXCEPTION`, `LEGACY` or `CORRECTED`.
- val DATETIME_REBASE_MODE = "datetimeRebaseMode"
+ val DATETIME_REBASE_MODE = newOption("datetimeRebaseMode")
}
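
The keys registered through `newOption` above keep their existing user-facing names (`ignoreExtension`, `avroSchema`, `datetimeRebaseMode`, ...); the constants only centralize the option strings in one place, so reader code is unaffected. A hedged sketch of passing a couple of these options from the DataFrame reader (the path is a placeholder, and the spark-avro module is assumed to be on the classpath):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class AvroOptionsExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("avro-options-sketch")
        .master("local[*]")
        .getOrCreate();

    Dataset<Row> df = spark.read()
        .format("avro")
        // CORRECTED: read ancient dates/timestamps as-is (proleptic Gregorian) instead of failing.
        .option("datetimeRebaseMode", "CORRECTED")
        // false matches Avro fields to Catalyst fields by name; true matches by position.
        .option("positionalFieldMatching", "false")
        .load("/tmp/events"); // placeholder directory of .avro files

    df.printSchema();
    spark.stop();
  }
}
```
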
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
similarity index 100%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
similarity index 100%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
similarity index 78%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 4a82df6ba0dce..c95d731f0dedd 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -32,7 +32,7 @@ import org.apache.avro.generic.GenericData.Record
import org.apache.avro.util.Utf8
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField}
+import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, AvroMatchedField}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -218,6 +218,17 @@ private[sql] class AvroSerializer(
val numFields = st.length
(getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
+ case (st: StructType, UNION) =>
+ val unionConvertor = newComplexUnionConverter(st, avroType, catalystPath, avroPath)
+ val numFields = st.length
+ (getter, ordinal) => unionConvertor(getter.getStruct(ordinal, numFields))
+
+ case (DoubleType, UNION) if nonNullUnionTypes(avroType) == Set(FLOAT, DOUBLE) =>
+ (getter, ordinal) => getter.getDouble(ordinal)
+
+ case (LongType, UNION) if nonNullUnionTypes(avroType) == Set(INT, LONG) =>
+ (getter, ordinal) => getter.getLong(ordinal)
+
case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
val valueConverter = newConverter(
vt, resolveNullableType(avroType.getValueType, valueContainsNull),
@@ -287,14 +298,59 @@ private[sql] class AvroSerializer(
result
}
+ /**
+ * Complex unions map to struct types where field names are member0, member1, etc.
+ * This is consistent with the behavior in [[SchemaConverters]] and when converting between Avro
+ * and Parquet.
+ */
+ private def newComplexUnionConverter(
+ catalystStruct: StructType,
+ unionType: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): InternalRow => Any = {
+ val nonNullTypes = nonNullUnionBranches(unionType)
+ val expectedFieldNames = nonNullTypes.indices.map(i => s"member$i")
+ val catalystFieldNames = catalystStruct.fieldNames.toSeq
+ if (positionalFieldMatch) {
+ if (expectedFieldNames.length != catalystFieldNames.length) {
+ throw new IncompatibleSchemaException(s"Generic Avro union at ${toFieldStr(avroPath)} " +
+ s"does not match the SQL schema at ${toFieldStr(catalystPath)}. It expected the " +
+ s"${expectedFieldNames.length} members but got ${catalystFieldNames.length}")
+ }
+ } else {
+ if (catalystFieldNames != expectedFieldNames) {
+ throw new IncompatibleSchemaException(s"Generic Avro union at ${toFieldStr(avroPath)} " +
+ s"does not match the SQL schema at ${toFieldStr(catalystPath)}. It expected the " +
+ s"following members ${expectedFieldNames.mkString("(", ", ", ")")} but got " +
+ s"${catalystFieldNames.mkString("(", ", ", ")")}")
+ }
+ }
+
+ val unionBranchConverters = nonNullTypes.zip(catalystStruct).map { case (unionBranch, cf) =>
+ newConverter(cf.dataType, unionBranch, catalystPath :+ cf.name, avroPath :+ cf.name)
+ }.toArray
+
+ val numBranches = catalystStruct.length
+ row: InternalRow => {
+ var idx = 0
+ var retVal: Any = null
+ while (idx < numBranches && retVal == null) {
+ if (!row.isNullAt(idx)) {
+ retVal = unionBranchConverters(idx).apply(row, idx)
+ }
+ idx += 1
+ }
+ retVal
+ }
+ }
+
/**
* Resolve a possibly nullable Avro Type.
*
- * An Avro type is nullable when it is a [[UNION]] of two types: one null type and another
- * non-null type. This method will check the nullability of the input Avro type and return the
- * non-null type within when it is nullable. Otherwise it will return the input Avro type
- * unchanged. It will throw an [[UnsupportedAvroTypeException]] when the input Avro type is an
- * unsupported nullable type.
+ * An Avro type is nullable when it is a [[UNION]] which contains a null type. This method will
+ * check the nullability of the input Avro type.
+ * Returns the non-null type within the union when it contains only 1 non-null type.
+ * Otherwise it will return the input Avro type unchanged.
*
* It will also log a warning message if the nullability for Avro and catalyst types are
* different.
@@ -306,20 +362,18 @@ private[sql] class AvroSerializer(
}
/**
- * Check the nullability of the input Avro type and resolve it when it is nullable. The first
- * return value is a [[Boolean]] indicating if the input Avro type is nullable. The second
- * return value is the possibly resolved type.
+ * Check the nullability of the input Avro type and resolve it when it is a single nullable type.
+ * The first return value is a [[Boolean]] indicating if the input Avro type is nullable.
+ * The second return value is the possibly resolved type otherwise the input Avro type unchanged.
*/
private def resolveAvroType(avroType: Schema): (Boolean, Schema) = {
if (avroType.getType == Type.UNION) {
- val fields = avroType.getTypes.asScala
- val actualType = fields.filter(_.getType != Type.NULL)
- if (fields.length != 2 || actualType.length != 1) {
- throw new UnsupportedAvroTypeException(
- s"Unsupported Avro UNION type $avroType: Only UNION of a null type and a non-null " +
- "type is supported")
+ val containsNull = avroType.getTypes.asScala.exists(_.getType == Schema.Type.NULL)
+ nonNullUnionBranches(avroType) match {
+ case Seq() => (true, Schema.create(Type.NULL))
+ case Seq(singleType) => (containsNull, singleType)
+ case _ => (containsNull, avroType)
}
- (true, actualType.head)
} else {
(false, avroType)
}
@@ -337,4 +391,8 @@ private[sql] class AvroSerializer(
"schema will throw runtime exception if there is a record with null value.")
}
}
+
+ private def nonNullUnionTypes(avroType: Schema): Set[Type] = {
+ nonNullUnionBranches(avroType).map(_.getType).toSet
+ }
}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
similarity index 95%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
index de3626b1f3147..e1966bd1041c2 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -34,8 +34,9 @@ import org.apache.hadoop.mapreduce.Job
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.avro.AvroOptions.ignoreExtensionKey
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.avro.AvroOptions.IGNORE_EXTENSION
+import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.OutputWriterFactory
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -49,15 +50,15 @@ private[sql] object AvroUtils extends Logging {
val conf = spark.sessionState.newHadoopConfWithOptions(options)
val parsedOptions = new AvroOptions(options, conf)
- if (parsedOptions.parameters.contains(ignoreExtensionKey)) {
- logWarning(s"Option $ignoreExtensionKey is deprecated. Please use the " +
+ if (parsedOptions.parameters.contains(IGNORE_EXTENSION)) {
+ logWarning(s"Option $IGNORE_EXTENSION is deprecated. Please use the " +
"general data source option pathGlobFilter for filtering file names.")
}
// User can specify an optional avro json schema.
val avroSchema = parsedOptions.schema
.getOrElse {
inferAvroSchemaFromFiles(files, conf, parsedOptions.ignoreExtension,
- spark.sessionState.conf.ignoreCorruptFiles)
+ new FileSourceOptions(CaseInsensitiveMap(options)).ignoreCorruptFiles)
}
SchemaConverters.toSqlType(avroSchema).dataType match {
@@ -335,4 +336,9 @@ private[sql] object AvroUtils extends Logging {
private[avro] def isNullable(avroField: Schema.Field): Boolean =
avroField.schema().getType == Schema.Type.UNION &&
avroField.schema().getTypes.asScala.exists(_.getType == Schema.Type.NULL)
+
+ /** Collect all non null branches of a union in order. */
+ private[avro] def nonNullUnionBranches(avroType: Schema): Seq[Schema] = {
+ avroType.getTypes.asScala.filter(_.getType != Schema.Type.NULL).toSeq
+ }
}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
similarity index 100%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/CatalystDataToAvro.scala
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
new file mode 100644
index 0000000000000..f616cfa9b5d5c
--- /dev/null
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import scala.collection.JavaConverters._
+
+import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
+import org.apache.avro.LogicalTypes.{Date, Decimal, LocalTimestampMicros, LocalTimestampMillis, TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema.Type._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.Decimal.minBytesForPrecision
+
+/**
+ * This object contains methods that are used to convert Spark SQL schemas to Avro schemas and
+ * vice versa.
+ */
+@DeveloperApi
+object SchemaConverters {
+ private lazy val nullSchema = Schema.create(Schema.Type.NULL)
+
+ /**
+ * Internal wrapper for SQL data type and nullability.
+ *
+ * @since 2.4.0
+ */
+ case class SchemaType(dataType: DataType, nullable: Boolean)
+
+ /**
+ * Converts an Avro schema to a corresponding Spark SQL schema.
+ *
+ * @since 2.4.0
+ */
+ def toSqlType(avroSchema: Schema): SchemaType = {
+ toSqlTypeHelper(avroSchema, Set.empty)
+ }
+
+ // The property specifies Catalyst type of the given field
+ private val CATALYST_TYPE_PROP_NAME = "spark.sql.catalyst.type"
+
+ private def toSqlTypeHelper(avroSchema: Schema, existingRecordNames: Set[String]): SchemaType = {
+ avroSchema.getType match {
+ case INT => avroSchema.getLogicalType match {
+ case _: Date => SchemaType(DateType, nullable = false)
+ case _ =>
+ val catalystTypeAttrValue = avroSchema.getProp(CATALYST_TYPE_PROP_NAME)
+ val catalystType = if (catalystTypeAttrValue == null) {
+ IntegerType
+ } else {
+ CatalystSqlParser.parseDataType(catalystTypeAttrValue)
+ }
+ SchemaType(catalystType, nullable = false)
+ }
+ case STRING => SchemaType(StringType, nullable = false)
+ case BOOLEAN => SchemaType(BooleanType, nullable = false)
+ case BYTES | FIXED => avroSchema.getLogicalType match {
+ // For FIXED type, if the precision requires more bytes than fixed size, the logical
+ // type will be null, which is handled by Avro library.
+ case d: Decimal => SchemaType(DecimalType(d.getPrecision, d.getScale), nullable = false)
+ case _ => SchemaType(BinaryType, nullable = false)
+ }
+
+ case DOUBLE => SchemaType(DoubleType, nullable = false)
+ case FLOAT => SchemaType(FloatType, nullable = false)
+ case LONG => avroSchema.getLogicalType match {
+ case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false)
+ case _: LocalTimestampMillis | _: LocalTimestampMicros =>
+ SchemaType(TimestampNTZType, nullable = false)
+ case _ =>
+ val catalystTypeAttrValue = avroSchema.getProp(CATALYST_TYPE_PROP_NAME)
+ val catalystType = if (catalystTypeAttrValue == null) {
+ LongType
+ } else {
+ CatalystSqlParser.parseDataType(catalystTypeAttrValue)
+ }
+ SchemaType(catalystType, nullable = false)
+ }
+
+ case ENUM => SchemaType(StringType, nullable = false)
+
+ case NULL => SchemaType(NullType, nullable = true)
+
+ case RECORD =>
+ if (existingRecordNames.contains(avroSchema.getFullName)) {
+ throw new IncompatibleSchemaException(s"""
+ |Found recursive reference in Avro schema, which can not be processed by Spark:
+ |${avroSchema.toString(true)}
+ """.stripMargin)
+ }
+ val newRecordNames = existingRecordNames + avroSchema.getFullName
+ val fields = avroSchema.getFields.asScala.map { f =>
+ val schemaType = toSqlTypeHelper(f.schema(), newRecordNames)
+ StructField(f.name, schemaType.dataType, schemaType.nullable)
+ }
+
+ SchemaType(StructType(fields.toArray), nullable = false)
+
+ case ARRAY =>
+ val schemaType = toSqlTypeHelper(avroSchema.getElementType, existingRecordNames)
+ SchemaType(
+ ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
+ nullable = false)
+
+ case MAP =>
+ val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames)
+ SchemaType(
+ MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
+ nullable = false)
+
+ case UNION =>
+ if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
+ // In case of a union with null, eliminate it and make a recursive call
+ val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema)
+ if (remainingUnionTypes.size == 1) {
+ toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames).copy(nullable = true)
+ } else {
+ toSqlTypeHelper(Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames)
+ .copy(nullable = true)
+ }
+ } else avroSchema.getTypes.asScala.map(_.getType).toSeq match {
+ case Seq(t1) =>
+ toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames)
+ case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
+ SchemaType(LongType, nullable = false)
+ case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
+ SchemaType(DoubleType, nullable = false)
+ case _ =>
+ // Convert complex unions to struct types where field names are member0, member1, etc.
+ // This is consistent with the behavior when converting between Avro and Parquet.
+ val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
+ case (s, i) =>
+ val schemaType = toSqlTypeHelper(s, existingRecordNames)
+ // All fields are nullable because only one of them is set at a time
+ StructField(s"member$i", schemaType.dataType, nullable = true)
+ }
+
+ SchemaType(StructType(fields.toArray), nullable = false)
+ }
+
+ case other => throw new IncompatibleSchemaException(s"Unsupported type $other")
+ }
+ }
+
+ /**
+ * Converts a Spark SQL schema to a corresponding Avro schema.
+ *
+ * @since 2.4.0
+ */
+ def toAvroType(
+ catalystType: DataType,
+ nullable: Boolean = false,
+ recordName: String = "topLevelRecord",
+ nameSpace: String = "")
+ : Schema = {
+ val builder = SchemaBuilder.builder()
+
+ val schema = catalystType match {
+ case BooleanType => builder.booleanType()
+ case ByteType | ShortType | IntegerType => builder.intType()
+ case LongType => builder.longType()
+ case DateType =>
+ LogicalTypes.date().addToSchema(builder.intType())
+ case TimestampType =>
+ LogicalTypes.timestampMicros().addToSchema(builder.longType())
+ case TimestampNTZType =>
+ LogicalTypes.localTimestampMicros().addToSchema(builder.longType())
+
+ case FloatType => builder.floatType()
+ case DoubleType => builder.doubleType()
+ case StringType => builder.stringType()
+ case NullType => builder.nullType()
+ case d: DecimalType =>
+ val avroType = LogicalTypes.decimal(d.precision, d.scale)
+ val fixedSize = minBytesForPrecision(d.precision)
+ // Need to avoid naming conflict for the fixed fields
+ val name = nameSpace match {
+ case "" => s"$recordName.fixed"
+ case _ => s"$nameSpace.$recordName.fixed"
+ }
+ avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize))
+
+ case BinaryType => builder.bytesType()
+ case ArrayType(et, containsNull) =>
+ builder.array()
+ .items(toAvroType(et, containsNull, recordName, nameSpace))
+ case MapType(StringType, vt, valueContainsNull) =>
+ builder.map()
+ .values(toAvroType(vt, valueContainsNull, recordName, nameSpace))
+ case st: StructType =>
+ val childNameSpace = if (nameSpace != "") s"$nameSpace.$recordName" else recordName
+ val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields()
+ st.foreach { f =>
+ val fieldAvroType =
+ toAvroType(f.dataType, f.nullable, f.name, childNameSpace)
+ fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault()
+ }
+ fieldsAssembler.endRecord()
+
+ case ym: YearMonthIntervalType =>
+ val ymIntervalType = builder.intType()
+ ymIntervalType.addProp(CATALYST_TYPE_PROP_NAME, ym.typeName)
+ ymIntervalType
+ case dt: DayTimeIntervalType =>
+ val dtIntervalType = builder.longType()
+ dtIntervalType.addProp(CATALYST_TYPE_PROP_NAME, dt.typeName)
+ dtIntervalType
+
+ // This should never happen.
+ case other => throw new IncompatibleSchemaException(s"Unexpected type $other.")
+ }
+ if (nullable && catalystType != NullType) {
+ Schema.createUnion(schema, nullSchema)
+ } else {
+ schema
+ }
+ }
+}
+
+private[avro] class IncompatibleSchemaException(
+ msg: String, ex: Throwable = null) extends Exception(msg, ex)
+
+private[avro] class UnsupportedAvroTypeException(msg: String) extends Exception(msg)
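
`SchemaConverters` is marked `@DeveloperApi`, so the union handling above can be exercised directly. A small sketch (hypothetical schemas; called from Java through the object's static forwarders, with spark-avro and Avro on the classpath):

```java
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.spark.sql.avro.SchemaConverters;
import org.apache.spark.sql.types.DataTypes;

public class SchemaConvertersSketch {
  public static void main(String[] args) {
    // ["null","int","string"]: two non-null branches remain after dropping null, so this
    // "complex" union maps to a nullable struct<member0:int,member1:string>.
    Schema complexUnion = SchemaBuilder.unionOf()
        .nullType().and().intType().and().stringType().endUnion();
    System.out.println(SchemaConverters.toSqlType(complexUnion));

    // ["null","long"] has a single non-null branch and simply becomes a nullable LongType.
    Schema nullableLong = SchemaBuilder.unionOf().nullType().and().longType().endUnion();
    System.out.println(SchemaConverters.toSqlType(nullableLong));

    // The reverse direction: a nullable Spark string becomes the union ["string","null"].
    System.out.println(
        SchemaConverters.toAvroType(DataTypes.StringType, true, "topLevelRecord", ""));
  }
}
```
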
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala
similarity index 100%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/package.scala
similarity index 100%
rename from external/avro/src/main/scala/org/apache/spark/sql/avro/package.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/avro/package.scala
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
similarity index 100%
rename from external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
similarity index 90%
rename from external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
index a4dfdbfe68f9c..cc7bd180e8477 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
@@ -16,14 +16,11 @@
*/
package org.apache.spark.sql.v2.avro
-import java.net.URI
-
import scala.util.control.NonFatal
import org.apache.avro.file.DataFileReader
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
import org.apache.avro.mapred.FsInput
-import org.apache.hadoop.fs.Path
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
@@ -46,7 +43,7 @@ import org.apache.spark.util.SerializableConfiguration
* @param dataSchema Schema of AVRO files.
* @param readDataSchema Required data schema of AVRO files.
* @param partitionSchema Schema of partitions.
- * @param parsedOptions Options for parsing AVRO files.
+ * @param options Options for parsing AVRO files.
*/
case class AvroPartitionReaderFactory(
sqlConf: SQLConf,
@@ -54,17 +51,17 @@ case class AvroPartitionReaderFactory(
dataSchema: StructType,
readDataSchema: StructType,
partitionSchema: StructType,
- parsedOptions: AvroOptions,
+ options: AvroOptions,
filters: Seq[Filter]) extends FilePartitionReaderFactory with Logging {
- private val datetimeRebaseModeInRead = parsedOptions.datetimeRebaseModeInRead
+ private val datetimeRebaseModeInRead = options.datetimeRebaseModeInRead
override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value
- val userProvidedSchema = parsedOptions.schema
+ val userProvidedSchema = options.schema
- if (parsedOptions.ignoreExtension || partitionedFile.filePath.endsWith(".avro")) {
+ if (options.ignoreExtension || partitionedFile.urlEncodedPath.endsWith(".avro")) {
val reader = {
- val in = new FsInput(new Path(new URI(partitionedFile.filePath)), conf)
+ val in = new FsInput(partitionedFile.toPath, conf)
try {
val datumReader = userProvidedSchema match {
case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema)
@@ -104,7 +101,7 @@ case class AvroPartitionReaderFactory(
override val deserializer = new AvroDeserializer(
userProvidedSchema.getOrElse(reader.getSchema),
readDataSchema,
- parsedOptions.positionalFieldMatching,
+ options.positionalFieldMatching,
datetimeRebaseMode,
avroFilters)
override val stopPosition = partitionedFile.start + partitionedFile.length
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
similarity index 95%
rename from external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
index d0f38c12427c3..763b9abe4f91b 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
@@ -70,10 +70,6 @@ case class AvroScan(
override def hashCode(): Int = super.hashCode()
- override def description(): String = {
- super.description() + ", PushedFilters: " + pushedFilters.mkString("[", ", ", "]")
- }
-
override def getMetaData(): Map[String, String] = {
super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala
similarity index 94%
rename from external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala
index 8fae89a945826..754c58e65b016 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala
@@ -18,14 +18,13 @@ package org.apache.spark.sql.v2.avro
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.StructFilters
-import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
-class AvroScanBuilder (
+case class AvroScanBuilder (
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
schema: StructType,
@@ -33,7 +32,7 @@ class AvroScanBuilder (
options: CaseInsensitiveStringMap)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
- override def build(): Scan = {
+ override def build(): AvroScan = {
AvroScan(
sparkSession,
fileIndex,
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
similarity index 100%
rename from external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWrite.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWrite.scala
similarity index 100%
rename from external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWrite.scala
rename to connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWrite.scala
diff --git a/external/avro/src/test/java/org/apache/spark/sql/avro/JavaAvroFunctionsSuite.java b/connector/avro/src/test/java/org/apache/spark/sql/avro/JavaAvroFunctionsSuite.java
similarity index 100%
rename from external/avro/src/test/java/org/apache/spark/sql/avro/JavaAvroFunctionsSuite.java
rename to connector/avro/src/test/java/org/apache/spark/sql/avro/JavaAvroFunctionsSuite.java
diff --git a/external/avro/src/test/resources/before_1582_date_v2_4_5.avro b/connector/avro/src/test/resources/before_1582_date_v2_4_5.avro
similarity index 100%
rename from external/avro/src/test/resources/before_1582_date_v2_4_5.avro
rename to connector/avro/src/test/resources/before_1582_date_v2_4_5.avro
diff --git a/external/avro/src/test/resources/before_1582_date_v2_4_6.avro b/connector/avro/src/test/resources/before_1582_date_v2_4_6.avro
similarity index 100%
rename from external/avro/src/test/resources/before_1582_date_v2_4_6.avro
rename to connector/avro/src/test/resources/before_1582_date_v2_4_6.avro
diff --git a/external/avro/src/test/resources/before_1582_date_v3_2_0.avro b/connector/avro/src/test/resources/before_1582_date_v3_2_0.avro
similarity index 100%
rename from external/avro/src/test/resources/before_1582_date_v3_2_0.avro
rename to connector/avro/src/test/resources/before_1582_date_v3_2_0.avro
diff --git a/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro b/connector/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro
similarity index 100%
rename from external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro
rename to connector/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro
diff --git a/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro b/connector/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro
similarity index 100%
rename from external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro
rename to connector/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro
diff --git a/external/avro/src/test/resources/before_1582_timestamp_micros_v3_2_0.avro b/connector/avro/src/test/resources/before_1582_timestamp_micros_v3_2_0.avro
similarity index 100%
rename from external/avro/src/test/resources/before_1582_timestamp_micros_v3_2_0.avro
rename to connector/avro/src/test/resources/before_1582_timestamp_micros_v3_2_0.avro
diff --git a/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro b/connector/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro
similarity index 100%
rename from external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro
rename to connector/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro
diff --git a/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro b/connector/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro
similarity index 100%
rename from external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro
rename to connector/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro
diff --git a/external/avro/src/test/resources/before_1582_timestamp_millis_v3_2_0.avro b/connector/avro/src/test/resources/before_1582_timestamp_millis_v3_2_0.avro
similarity index 100%
rename from external/avro/src/test/resources/before_1582_timestamp_millis_v3_2_0.avro
rename to connector/avro/src/test/resources/before_1582_timestamp_millis_v3_2_0.avro
diff --git a/external/avro/src/test/resources/episodes.avro b/connector/avro/src/test/resources/episodes.avro
similarity index 100%
rename from external/avro/src/test/resources/episodes.avro
rename to connector/avro/src/test/resources/episodes.avro
diff --git a/external/avro/src/test/resources/log4j2.properties b/connector/avro/src/test/resources/log4j2.properties
similarity index 100%
rename from external/avro/src/test/resources/log4j2.properties
rename to connector/avro/src/test/resources/log4j2.properties
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00000.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00000.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00000.avro
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00001.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00001.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00001.avro
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00002.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00002.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00002.avro
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00003.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00003.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00003.avro
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00004.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00004.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00004.avro
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00005.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00005.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00005.avro
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00006.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00006.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00006.avro
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00007.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00007.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00007.avro
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00008.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00008.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00008.avro
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00009.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00009.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00009.avro
diff --git a/external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro b/connector/avro/src/test/resources/test-random-partitioned/part-r-00010.avro
similarity index 100%
rename from external/avro/src/test/resources/test-random-partitioned/part-r-00010.avro
rename to connector/avro/src/test/resources/test-random-partitioned/part-r-00010.avro
diff --git a/external/avro/src/test/resources/test.avro b/connector/avro/src/test/resources/test.avro
similarity index 100%
rename from external/avro/src/test/resources/test.avro
rename to connector/avro/src/test/resources/test.avro
diff --git a/external/avro/src/test/resources/test.avsc b/connector/avro/src/test/resources/test.avsc
similarity index 100%
rename from external/avro/src/test/resources/test.avsc
rename to connector/avro/src/test/resources/test.avsc
diff --git a/external/avro/src/test/resources/test.json b/connector/avro/src/test/resources/test.json
similarity index 100%
rename from external/avro/src/test/resources/test.json
rename to connector/avro/src/test/resources/test.json
diff --git a/external/avro/src/test/resources/test_sub.avsc b/connector/avro/src/test/resources/test_sub.avsc
similarity index 100%
rename from external/avro/src/test/resources/test_sub.avsc
rename to connector/avro/src/test/resources/test_sub.avsc
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
similarity index 100%
rename from external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
similarity index 100%
rename from external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
new file mode 100644
index 0000000000000..abc0c3d3155d2
--- /dev/null
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.io.ByteArrayOutputStream
+
+import scala.collection.JavaConverters._
+
+import org.apache.avro.{Schema, SchemaBuilder}
+import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecordBuilder}
+import org.apache.avro.io.EncoderFactory
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.execution.LocalTableScanExec
+import org.apache.spark.sql.functions.{col, lit, struct}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+
+class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ test("roundtrip in to_avro and from_avro - int and string") {
+ val df = spark.range(10).select($"id", $"id".cast("string").as("str"))
+
+ val avroDF = df.select(
+ functions.to_avro($"id").as("a"),
+ functions.to_avro($"str").as("b"))
+ val avroTypeLong = s"""
+ |{
+ | "type": "int",
+ | "name": "id"
+ |}
+ """.stripMargin
+ val avroTypeStr = s"""
+ |{
+ | "type": "string",
+ | "name": "str"
+ |}
+ """.stripMargin
+ checkAnswer(avroDF.select(
+ functions.from_avro($"a", avroTypeLong),
+ functions.from_avro($"b", avroTypeStr)), df)
+ }
+
+ test("roundtrip in to_avro and from_avro - struct") {
+ val df = spark.range(10).select(struct($"id", $"id".cast("string").as("str")).as("struct"))
+ val avroStructDF = df.select(functions.to_avro($"struct").as("avro"))
+ val avroTypeStruct = s"""
+ |{
+ | "type": "record",
+ | "name": "struct",
+ | "fields": [
+ | {"name": "col1", "type": "long"},
+ | {"name": "col2", "type": "string"}
+ | ]
+ |}
+ """.stripMargin
+ checkAnswer(avroStructDF.select(
+ functions.from_avro($"avro", avroTypeStruct)), df)
+ }
+
+ test("handle invalid input in from_avro") {
+ val count = 10
+ val df = spark.range(count).select(struct($"id", $"id".as("id2")).as("struct"))
+ val avroStructDF = df.select(functions.to_avro($"struct").as("avro"))
+ val avroTypeStruct = s"""
+ |{
+ | "type": "record",
+ | "name": "struct",
+ | "fields": [
+ | {"name": "col1", "type": "long"},
+ | {"name": "col2", "type": "double"}
+ | ]
+ |}
+ """.stripMargin
+
+ intercept[SparkException] {
+ avroStructDF.select(
+ functions.from_avro(
+ $"avro", avroTypeStruct, Map("mode" -> "FAILFAST").asJava)).collect()
+ }
+
+ // For PERMISSIVE mode, the result should be row of null columns.
+ val expected = (0 until count).map(_ => Row(Row(null, null)))
+ checkAnswer(
+ avroStructDF.select(
+ functions.from_avro(
+ $"avro", avroTypeStruct, Map("mode" -> "PERMISSIVE").asJava)),
+ expected)
+ }
+
+ test("roundtrip in to_avro and from_avro - array with null") {
+ val dfOne = Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array")
+ val avroTypeArrStruct = s"""
+ |[ {
+ | "type" : "array",
+ | "items" : [ {
+ | "type" : "record",
+ | "name" : "x",
+ | "fields" : [ {
+ | "name" : "y",
+ | "type" : "int"
+ | } ]
+ | }, "null" ]
+ |}, "null" ]
+ """.stripMargin
+ val readBackOne = dfOne.select(functions.to_avro($"array").as("avro"))
+ .select(functions.from_avro($"avro", avroTypeArrStruct).as("array"))
+ checkAnswer(dfOne, readBackOne)
+ }
+
+ test("SPARK-27798: from_avro produces same value when converted to local relation") {
+ val simpleSchema =
+ """
+ |{
+ | "type": "record",
+ | "name" : "Payload",
+ | "fields" : [ {"name" : "message", "type" : "string" } ]
+ |}
+ """.stripMargin
+
+ def generateBinary(message: String, avroSchema: String): Array[Byte] = {
+ val schema = new Schema.Parser().parse(avroSchema)
+ val out = new ByteArrayOutputStream()
+ val writer = new GenericDatumWriter[GenericRecord](schema)
+ val encoder = EncoderFactory.get().binaryEncoder(out, null)
+ val rootRecord = new GenericRecordBuilder(schema).set("message", message).build()
+ writer.write(rootRecord, encoder)
+ encoder.flush()
+ out.toByteArray
+ }
+
+ // This bug is hit when the rule `ConvertToLocalRelation` is run. But the rule was excluded
+ // in `SharedSparkSession`.
+ withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") {
+ val df = Seq("one", "two", "three", "four").map(generateBinary(_, simpleSchema))
+ .toDF()
+ .withColumn("value",
+ functions.from_avro(col("value"), simpleSchema))
+
+ assert(df.queryExecution.executedPlan.isInstanceOf[LocalTableScanExec])
+ assert(df.collect().map(_.get(0)) === Seq(Row("one"), Row("two"), Row("three"), Row("four")))
+ }
+ }
+
+ test("SPARK-27506: roundtrip in to_avro and from_avro with different compatible schemas") {
+ val df = spark.range(10).select(
+ struct($"id".as("col1"), $"id".cast("string").as("col2")).as("struct")
+ )
+ val avroStructDF = df.select(functions.to_avro($"struct").as("avro"))
+ val actualAvroSchema =
+ s"""
+ |{
+ | "type": "record",
+ | "name": "struct",
+ | "fields": [
+ | {"name": "col1", "type": "int"},
+ | {"name": "col2", "type": "string"}
+ | ]
+ |}
+ |""".stripMargin
+
+ val evolvedAvroSchema =
+ s"""
+ |{
+ | "type": "record",
+ | "name": "struct",
+ | "fields": [
+ | {"name": "col1", "type": "int"},
+ | {"name": "col2", "type": "string"},
+ | {"name": "col3", "type": "string", "default": ""}
+ | ]
+ |}
+ |""".stripMargin
+
+ val expected = spark.range(10).select(
+ struct($"id".as("col1"), $"id".cast("string").as("col2"), lit("").as("col3")).as("struct")
+ )
+
+ checkAnswer(
+ avroStructDF.select(
+ functions.from_avro(
+ $"avro",
+ actualAvroSchema,
+ Map("avroSchema" -> evolvedAvroSchema).asJava)),
+ expected)
+ }
+
+ test("roundtrip in to_avro and from_avro - struct with nullable Avro schema") {
+ val df = spark.range(10).select(struct($"id", $"id".cast("string").as("str")).as("struct"))
+ val avroTypeStruct = s"""
+ |{
+ | "type": "record",
+ | "name": "struct",
+ | "fields": [
+ | {"name": "id", "type": "long"},
+ | {"name": "str", "type": ["null", "string"]}
+ | ]
+ |}
+ """.stripMargin
+ val avroStructDF = df.select(functions.to_avro($"struct", avroTypeStruct).as("avro"))
+ checkAnswer(avroStructDF.select(
+ functions.from_avro($"avro", avroTypeStruct)), df)
+ }
+
+ test("to_avro optional union Avro schema") {
+ val df = spark.range(10).select(struct($"id", $"id".cast("string").as("str")).as("struct"))
+ for (supportedAvroType <- Seq("""["null", "int", "long"]""", """["int", "long"]""")) {
+ val avroTypeStruct = s"""
+ |{
+ | "type": "record",
+ | "name": "struct",
+ | "fields": [
+ | {"name": "id", "type": $supportedAvroType},
+ | {"name": "str", "type": ["null", "string"]}
+ | ]
+ |}
+ """.stripMargin
+ val avroStructDF = df.select(functions.to_avro($"struct", avroTypeStruct).as("avro"))
+ checkAnswer(avroStructDF.select(
+ functions.from_avro($"avro", avroTypeStruct)), df)
+ }
+ }
+
+ test("to_avro complex union Avro schema") {
+ val df = Seq((Some(1), None), (None, Some("a"))).toDF()
+ .select(struct(struct($"_1".as("member0"), $"_2".as("member1")).as("u")).as("struct"))
+ val avroTypeStruct = SchemaBuilder.record("struct").fields()
+ .name("u").`type`().unionOf().intType().and().stringType().endUnion().noDefault()
+ .endRecord().toString
+ val avroStructDF = df.select(functions.to_avro($"struct", avroTypeStruct).as("avro"))
+ checkAnswer(avroStructDF.select(
+ functions.from_avro($"avro", avroTypeStruct)), df)
+ }
+
+ test("SPARK-39775: Disable validate default values when parsing Avro schemas") {
+ val avroTypeStruct = s"""
+ |{
+ | "type": "record",
+ | "name": "struct",
+ | "fields": [
+ | {"name": "id", "type": "long", "default": null}
+ | ]
+ |}
+ """.stripMargin
+ val avroSchema = AvroOptions(Map("avroSchema" -> avroTypeStruct)).schema.get
+ val sparkSchema = SchemaConverters.toSqlType(avroSchema).dataType.asInstanceOf[StructType]
+
+ val df = spark.range(5).select($"id")
+ val structDf = df.select(struct($"id").as("struct"))
+ val avroStructDF = structDf.select(functions.to_avro($"struct", avroTypeStruct).as("avro"))
+ checkAnswer(avroStructDF.select(functions.from_avro($"avro", avroTypeStruct)), structDf)
+
+ withTempPath { dir =>
+ df.write.format("avro").save(dir.getCanonicalPath)
+ checkAnswer(spark.read.schema(sparkSchema).format("avro").load(dir.getCanonicalPath), df)
+
+ val msg = intercept[SparkException] {
+ spark.read.option("avroSchema", avroTypeStruct).format("avro")
+ .load(dir.getCanonicalPath)
+ .collect()
+ }.getCause.getMessage
+ assert(msg.contains("Invalid default for field id: null not a \"long\""))
+ }
+ }
+}
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
similarity index 94%
rename from external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
index b7ac10c58e24a..c0022c62735c8 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
@@ -24,7 +24,7 @@ import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.file.DataFileWriter
import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkArithmeticException, SparkConf, SparkException}
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
@@ -129,7 +129,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
withTempDir { dir =>
val expected = timestampInputData.map(t => Row(new Timestamp(t._1)))
val timestampAvro = timestampFile(dir.getAbsolutePath)
- val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis)
+ val df = spark.read.format("avro").load(timestampAvro).select($"timestamp_millis")
checkAnswer(df, expected)
@@ -144,7 +144,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
withTempDir { dir =>
val expected = timestampInputData.map(t => Row(new Timestamp(t._2)))
val timestampAvro = timestampFile(dir.getAbsolutePath)
- val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros)
+ val df = spark.read.format("avro").load(timestampAvro).select($"timestamp_micros")
checkAnswer(df, expected)
@@ -160,7 +160,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
val expected = timestampInputData.map(t =>
Row(DateTimeUtils.microsToLocalDateTime(DateTimeUtils.millisToMicros(t._3))))
val timestampAvro = timestampFile(dir.getAbsolutePath)
- val df = spark.read.format("avro").load(timestampAvro).select('local_timestamp_millis)
+ val df = spark.read.format("avro").load(timestampAvro).select($"local_timestamp_millis")
checkAnswer(df, expected)
@@ -176,7 +176,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
val expected = timestampInputData.map(t =>
Row(DateTimeUtils.microsToLocalDateTime(DateTimeUtils.millisToMicros(t._4))))
val timestampAvro = timestampFile(dir.getAbsolutePath)
- val df = spark.read.format("avro").load(timestampAvro).select('local_timestamp_micros)
+ val df = spark.read.format("avro").load(timestampAvro).select($"local_timestamp_micros")
checkAnswer(df, expected)
@@ -194,7 +194,8 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
withTempDir { dir =>
val timestampAvro = timestampFile(dir.getAbsolutePath)
val df =
- spark.read.format("avro").load(timestampAvro).select('timestamp_millis, 'timestamp_micros)
+ spark.read.format("avro").load(timestampAvro)
+ .select($"timestamp_millis", $"timestamp_micros")
val expected = timestampInputData.map(t => Row(new Timestamp(t._1), new Timestamp(t._2)))
@@ -226,7 +227,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
withTempDir { dir =>
val timestampAvro = timestampFile(dir.getAbsolutePath)
val df = spark.read.format("avro").load(timestampAvro).select(
- 'local_timestamp_millis, 'local_timestamp_micros)
+ $"local_timestamp_millis", $"local_timestamp_micros")
val expected = timestampInputData.map(t =>
Row(DateTimeUtils.microsToLocalDateTime(DateTimeUtils.millisToMicros(t._3)),
@@ -260,7 +261,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
withTempDir { dir =>
val timestampAvro = timestampFile(dir.getAbsolutePath)
val schema = StructType(StructField("long", TimestampType, true) :: Nil)
- val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long)
+ val df = spark.read.format("avro").schema(schema).load(timestampAvro).select($"long")
val expected = timestampInputData.map(t => Row(new Timestamp(t._5)))
@@ -272,7 +273,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
withTempDir { dir =>
val timestampAvro = timestampFile(dir.getAbsolutePath)
val schema = StructType(StructField("long", TimestampNTZType, true) :: Nil)
- val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long)
+ val df = spark.read.format("avro").schema(schema).load(timestampAvro).select($"long")
val expected = timestampInputData.map(t =>
Row(DateTimeUtils.microsToLocalDateTime(DateTimeUtils.millisToMicros(t._5))))
@@ -432,10 +433,17 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
dataFileWriter.flush()
dataFileWriter.close()
- val msg = intercept[SparkException] {
- spark.read.format("avro").load(s"$dir.avro").collect()
- }.getCause.getCause.getMessage
- assert(msg.contains("Unscaled value too large for precision"))
+ checkError(
+ exception = intercept[SparkException] {
+ spark.read.format("avro").load(s"$dir.avro").collect()
+ }.getCause.getCause.asInstanceOf[SparkArithmeticException],
+ errorClass = "NUMERIC_VALUE_OUT_OF_RANGE",
+ parameters = Map(
+ "value" -> "0",
+ "precision" -> "4",
+ "scale" -> "2",
+ "config" -> "\"spark.sql.ansi.enabled\"")
+ )
}
}
}
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
similarity index 96%
rename from external/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
index 08c61381c5780..046ff4ef088d8 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
@@ -59,11 +59,13 @@ class AvroRowReaderSuite
val df = spark.read.format("avro").load(dir.getCanonicalPath)
val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: AvroScan, _, _) => f
+ case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
}
val filePath = fileScan.get.fileIndex.inputFiles(0)
val fileSize = new File(new URI(filePath)).length
+ // scalastyle:off pathfromuri
val in = new FsInput(new Path(new URI(filePath)), new Configuration())
+ // scalastyle:on pathfromuri
val reader = DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]())
val it = new Iterator[InternalRow] with AvroUtils.RowReader {
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala
similarity index 100%
rename from external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala
similarity index 100%
rename from external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala
similarity index 100%
rename from external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
similarity index 91%
rename from external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index e93c1c09c9fc2..d19a11b4546a7 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -299,21 +299,27 @@ abstract class AvroSuite
test("Complex Union Type") {
withTempPath { dir =>
- val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4)
- val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava)
- val complexUnionType = Schema.createUnion(
- List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava)
- val fields = Seq(
- new Field("field1", complexUnionType, "doc", null.asInstanceOf[AnyVal]),
- new Field("field2", complexUnionType, "doc", null.asInstanceOf[AnyVal]),
- new Field("field3", complexUnionType, "doc", null.asInstanceOf[AnyVal]),
- new Field("field4", complexUnionType, "doc", null.asInstanceOf[AnyVal])
- ).asJava
- val schema = Schema.createRecord("name", "docs", "namespace", false)
- schema.setFields(fields)
+ val nativeWriterPath = s"$dir.avro"
+ val sparkWriterPath = s"$dir/spark"
+ val fixedSchema = SchemaBuilder.fixed("fixed_name").size(4)
+ val enumSchema = SchemaBuilder.enumeration("enum_name").symbols("e1", "e2")
+ val complexUnionType = SchemaBuilder.unionOf()
+ .intType().and()
+ .stringType().and()
+ .`type`(fixedSchema).and()
+ .`type`(enumSchema).and()
+ .nullType()
+ .endUnion()
+ val schema = SchemaBuilder.record("name").fields()
+ .name("field1").`type`(complexUnionType).noDefault()
+ .name("field2").`type`(complexUnionType).noDefault()
+ .name("field3").`type`(complexUnionType).noDefault()
+ .name("field4").`type`(complexUnionType).noDefault()
+ .name("field5").`type`(complexUnionType).noDefault()
+ .endRecord()
val datumWriter = new GenericDatumWriter[GenericRecord](schema)
val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
- dataFileWriter.create(schema, new File(s"$dir.avro"))
+ dataFileWriter.create(schema, new File(nativeWriterPath))
val avroRec = new GenericData.Record(schema)
val field1 = 1234
val field2 = "Hope that was not load bearing"
@@ -323,15 +329,32 @@ abstract class AvroSuite
avroRec.put("field2", field2)
avroRec.put("field3", new Fixed(fixedSchema, field3))
avroRec.put("field4", new EnumSymbol(enumSchema, field4))
+ avroRec.put("field5", null)
dataFileWriter.append(avroRec)
dataFileWriter.flush()
dataFileWriter.close()
- val df = spark.sqlContext.read.format("avro").load(s"$dir.avro")
- assertResult(field1)(df.selectExpr("field1.member0").first().get(0))
- assertResult(field2)(df.selectExpr("field2.member1").first().get(0))
- assertResult(field3)(df.selectExpr("field3.member2").first().get(0))
- assertResult(field4)(df.selectExpr("field4.member3").first().get(0))
+ val df = spark.sqlContext.read.format("avro").load(nativeWriterPath)
+ assertResult(Row(field1, null, null, null))(df.selectExpr("field1.*").first())
+ assertResult(Row(null, field2, null, null))(df.selectExpr("field2.*").first())
+ assertResult(Row(null, null, field3, null))(df.selectExpr("field3.*").first())
+ assertResult(Row(null, null, null, field4))(df.selectExpr("field4.*").first())
+ assertResult(Row(null, null, null, null))(df.selectExpr("field5.*").first())
+
+ df.write.format("avro").option("avroSchema", schema.toString).save(sparkWriterPath)
+
+ val df2 = spark.sqlContext.read.format("avro").load(sparkWriterPath)
+ assertResult(Row(field1, null, null, null))(df2.selectExpr("field1.*").first())
+ assertResult(Row(null, field2, null, null))(df2.selectExpr("field2.*").first())
+ assertResult(Row(null, null, field3, null))(df2.selectExpr("field3.*").first())
+ assertResult(Row(null, null, null, field4))(df2.selectExpr("field4.*").first())
+ assertResult(Row(null, null, null, null))(df2.selectExpr("field5.*").first())
+
+ val reader = openDatumReader(new File(sparkWriterPath))
+ assert(reader.hasNext)
+ assertResult(avroRec)(reader.next())
+ assert(!reader.hasNext)
+ reader.close()
}
}
@@ -550,8 +573,8 @@ abstract class AvroSuite
val fixed = spark.read.format("avro").load(testAvro).select("fixed3").collect()
assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3))
- val enum = spark.read.format("avro").load(testAvro).select("enum").collect()
- assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS"))
+ val enums = spark.read.format("avro").load(testAvro).select("enum").collect()
+ assert(enums.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS"))
val record = spark.read.format("avro").load(testAvro).select("record").collect()
assert(record(0)(0).getClass.toString.contains("Row"))
@@ -875,7 +898,7 @@ abstract class AvroSuite
dfWithNull.write.format("avro")
.option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}")
}
- assertExceptionMsg[AvroTypeException](e1, "Not an enum: null")
+ assertExceptionMsg[AvroTypeException](e1, "value null is not a SuitEnumType")
// Writing df containing data not in the enum will throw an exception
val e2 = intercept[SparkException] {
@@ -1069,14 +1092,13 @@ abstract class AvroSuite
df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir)
checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir))
- val message = intercept[Exception] {
+ val message = intercept[SparkException] {
spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row(2, null))), catalystSchema)
.write.format("avro").option("avroSchema", avroSchema)
.save(s"$tempDir/${UUID.randomUUID()}")
- }.getCause.getMessage
+ }.getMessage
assert(message.contains("Caused by: java.lang.NullPointerException: "))
- assert(message.contains(
- "null of string in string in field Name of test_schema in test_schema"))
+ assert(message.contains("null value for (non-nullable) string at test_schema.Name"))
}
}
@@ -1144,32 +1166,81 @@ abstract class AvroSuite
}
}
- test("unsupported nullable avro type") {
+ test("int/long double/float conversion") {
val catalystSchema =
StructType(Seq(
- StructField("Age", IntegerType, nullable = false),
- StructField("Name", StringType, nullable = false)))
+ StructField("Age", LongType),
+ StructField("Length", DoubleType),
+ StructField("Name", StringType)))
- for (unsupportedAvroType <- Seq("""["null", "int", "long"]""", """["int", "long"]""")) {
+ for (optionalNull <- Seq(""""null",""", "")) {
val avroSchema = s"""
|{
| "type" : "record",
| "name" : "test_schema",
| "fields" : [
- | {"name": "Age", "type": $unsupportedAvroType},
+ | {"name": "Age", "type": [$optionalNull "int", "long"]},
+ | {"name": "Length", "type": [$optionalNull "float", "double"]},
| {"name": "Name", "type": ["null", "string"]}
| ]
|}
""".stripMargin
val df = spark.createDataFrame(
- spark.sparkContext.parallelize(Seq(Row(2, "Aurora"))), catalystSchema)
+ spark.sparkContext.parallelize(Seq(Row(2L, 1.8D, "Aurora"), Row(1L, 0.9D, null))),
+ catalystSchema)
+
+ withTempPath { tempDir =>
+ df.write.format("avro").option("avroSchema", avroSchema).save(tempDir.getPath)
+ checkAnswer(
+ spark.read
+ .format("avro")
+ .option("avroSchema", avroSchema)
+ .load(tempDir.getPath),
+ df)
+ }
+ }
+ }
+
+ test("non-matching complex union types") {
+ val catalystSchema = new StructType().add("Union", new StructType()
+ .add("member0", IntegerType)
+ .add("member1", new StructType().add("f1", StringType, nullable = false))
+ )
+
+ val df = spark.createDataFrame(
+ spark.sparkContext.parallelize(Seq(Row(Row(1, null)))), catalystSchema)
+
+ val recordS = SchemaBuilder.record("r").fields().requiredString("f1").endRecord()
+ val intS = Schema.create(Schema.Type.INT)
+ val nullS = Schema.create(Schema.Type.NULL)
+ for ((unionTypes, compatible) <- Seq(
+ (Seq(nullS, intS, recordS), true),
+ (Seq(intS, nullS, recordS), true),
+ (Seq(intS, recordS, nullS), true),
+ (Seq(intS, recordS), true),
+ (Seq(nullS, recordS, intS), false),
+ (Seq(nullS, recordS), false),
+ (Seq(nullS, SchemaBuilder.record("r").fields().requiredString("f2").endRecord()), false)
+ )) {
+ val avroSchema = SchemaBuilder.record("test_schema").fields()
+ .name("union").`type`(Schema.createUnion(unionTypes: _*)).noDefault()
+ .endRecord().toString()
withTempPath { tempDir =>
- val message = intercept[SparkException] {
+ if (!compatible) {
+ intercept[SparkException] {
+ df.write.format("avro").option("avroSchema", avroSchema).save(tempDir.getPath)
+ }
+ } else {
df.write.format("avro").option("avroSchema", avroSchema).save(tempDir.getPath)
- }.getCause.getMessage
- assert(message.contains("Only UNION of a null type and a non-null type is supported"))
+ checkAnswer(
+ spark.read
+ .format("avro")
+ .option("avroSchema", avroSchema)
+ .load(tempDir.getPath),
+ df)
+ }
}
}
}
@@ -1182,14 +1253,16 @@ abstract class AvroSuite
sql("select interval 1 days").write.format("avro").mode("overwrite").save(tempDir)
}.getMessage
assert(msg.contains("Cannot save interval data type into external storage.") ||
- msg.contains("AVRO data source does not support interval data type."))
+ msg.contains("Column `INTERVAL '1' DAY` has a data type of interval day, " +
+ "which is not supported by Avro."))
msg = intercept[AnalysisException] {
spark.udf.register("testType", () => new IntervalData())
sql("select testType()").write.format("avro").mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
- .contains(s"avro data source does not support interval data type."))
+ .contains("column `testtype()` has a data type of interval, " +
+ "which is not supported by avro."))
}
}
}
@@ -1803,13 +1876,13 @@ abstract class AvroSuite
spark
.read
.format("avro")
- .option(AvroOptions.ignoreExtensionKey, false)
+ .option(AvroOptions.IGNORE_EXTENSION, false)
.load(dir.getCanonicalPath)
.count()
}
val deprecatedEvents = logAppender.loggingEvents
.filter(_.getMessage.getFormattedMessage.contains(
- s"Option ${AvroOptions.ignoreExtensionKey} is deprecated"))
+ s"Option ${AvroOptions.IGNORE_EXTENSION} is deprecated"))
assert(deprecatedEvents.size === 1)
}
}
@@ -1817,7 +1890,7 @@ abstract class AvroSuite
// It generates input files for the test below:
// "SPARK-31183, SPARK-37705: compatibility with Spark 2.4/3.2 in reading dates/timestamps"
ignore("SPARK-31855: generate test files for checking compatibility with Spark 2.4/3.2") {
- val resourceDir = "external/avro/src/test/resources"
+ val resourceDir = "connector/avro/src/test/resources"
val version = SPARK_VERSION_SHORT.replaceAll("\\.", "_")
def save(
in: Seq[String],
@@ -1932,7 +2005,7 @@ abstract class AvroSuite
val e = intercept[SparkException] {
df.write.format("avro").option("avroSchema", avroSchema).save(path3_x)
}
- assert(e.getCause.getCause.getCause.isInstanceOf[SparkUpgradeException])
+ assert(e.getCause.getCause.isInstanceOf[SparkUpgradeException])
checkDefaultLegacyRead(oldPath)
withSQLConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) {
@@ -2103,12 +2176,15 @@ abstract class AvroSuite
}
private def checkMetaData(path: java.io.File, key: String, expectedValue: String): Unit = {
+ val value = openDatumReader(path).asInstanceOf[DataFileReader[_]].getMetaString(key)
+ assert(value === expectedValue)
+ }
+
+ private def openDatumReader(path: File): org.apache.avro.file.FileReader[GenericRecord] = {
val avroFiles = path.listFiles()
.filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
assert(avroFiles.length === 1)
- val reader = DataFileReader.openReader(avroFiles(0), new GenericDatumReader[GenericRecord]())
- val value = reader.asInstanceOf[DataFileReader[_]].getMetaString(key)
- assert(value === expectedValue)
+ DataFileReader.openReader(avroFiles(0), new GenericDatumReader[GenericRecord]())
}
test("SPARK-31327: Write Spark version into Avro file metadata") {
@@ -2183,7 +2259,7 @@ abstract class AvroSuite
val e = intercept[SparkException] {
df.write.format("avro").option("avroSchema", avroSchema).save(dir.getCanonicalPath)
}
- val errMsg = e.getCause.getCause.getCause.asInstanceOf[SparkUpgradeException].getMessage
+ val errMsg = e.getCause.getCause.asInstanceOf[SparkUpgradeException].getMessage
assert(errMsg.contains("You may get a different result due to the upgrading"))
}
}
@@ -2193,7 +2269,7 @@ abstract class AvroSuite
val e = intercept[SparkException] {
df.write.format("avro").save(dir.getCanonicalPath)
}
- val errMsg = e.getCause.getCause.getCause.asInstanceOf[SparkUpgradeException].getMessage
+ val errMsg = e.getCause.getCause.asInstanceOf[SparkUpgradeException].getMessage
assert(errMsg.contains("You may get a different result due to the upgrading"))
}
}
@@ -2218,14 +2294,18 @@ abstract class AvroSuite
withView("v") {
spark.range(1).createTempView("v")
withTempDir { dir =>
- val e = intercept[AnalysisException] {
- sql(
- s"""
- |CREATE TABLE test_ddl USING AVRO
- |LOCATION '${dir}'
- |AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin)
- }.getMessage
- assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s)."))
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(
+ s"""
+ |CREATE TABLE test_ddl USING AVRO
+ |LOCATION '${dir}'
+ |AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin)
+ },
+ errorClass = "INVALID_COLUMN_NAME_AS_PATH",
+ parameters = Map(
+ "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`")
+ )
}
withTempDir { dir =>
@@ -2271,6 +2351,20 @@ abstract class AvroSuite
checkAnswer(df2, df.collect().toSeq)
}
}
+
+ test("SPARK-40667: validate Avro Options") {
+ assert(AvroOptions.getAllOptions.size == 9)
+ // Please add validation on any new Avro options here
+ assert(AvroOptions.isValidOption("ignoreExtension"))
+ assert(AvroOptions.isValidOption("mode"))
+ assert(AvroOptions.isValidOption("recordName"))
+ assert(AvroOptions.isValidOption("compression"))
+ assert(AvroOptions.isValidOption("avroSchema"))
+ assert(AvroOptions.isValidOption("avroSchemaUrl"))
+ assert(AvroOptions.isValidOption("recordNamespace"))
+ assert(AvroOptions.isValidOption("positionalFieldMatching"))
+ assert(AvroOptions.isValidOption("datetimeRebaseMode"))
+ }
}
class AvroV1Suite extends AvroSuite {
@@ -2283,20 +2377,28 @@ class AvroV1Suite extends AvroSuite {
withView("v") {
spark.range(1).createTempView("v")
withTempDir { dir =>
- val e = intercept[AnalysisException] {
- sql("SELECT ID, IF(ID=1,1,0) FROM v").write.mode(SaveMode.Overwrite)
- .format("avro").save(dir.getCanonicalPath)
- }.getMessage
- assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s)."))
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("SELECT ID, IF(ID=1,1,0) FROM v").write.mode(SaveMode.Overwrite)
+ .format("avro").save(dir.getCanonicalPath)
+ },
+ errorClass = "INVALID_COLUMN_NAME_AS_PATH",
+ parameters = Map(
+ "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`")
+ )
}
withTempDir { dir =>
- val e = intercept[AnalysisException] {
- sql("SELECT NAMED_STRUCT('(IF((ID = 1), 1, 0))', IF(ID=1,ID,0)) AS col1 FROM v")
- .write.mode(SaveMode.Overwrite)
- .format("avro").save(dir.getCanonicalPath)
- }.getMessage
- assert(e.contains("Column name \"(IF((ID = 1), 1, 0))\" contains invalid character(s)."))
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("SELECT NAMED_STRUCT('(IF((ID = 1), 1, 0))', IF(ID=1,ID,0)) AS col1 FROM v")
+ .write.mode(SaveMode.Overwrite)
+ .format("avro").save(dir.getCanonicalPath)
+ },
+ errorClass = "INVALID_COLUMN_NAME_AS_PATH",
+ parameters = Map(
+ "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`")
+ )
}
}
}
@@ -2335,14 +2437,15 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
})
val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: AvroScan, _, _) => f
+ case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
assert(fileScan.get.dataFilters.nonEmpty)
assert(fileScan.get.planInputPartitions().forall { partition =>
partition.asInstanceOf[FilePartition].files.forall { file =>
- file.filePath.contains("p1=1") && file.filePath.contains("p2=2")
+ file.urlEncodedPath.contains("p1=1") &&
+ file.urlEncodedPath.contains("p2=2")
}
})
checkAnswer(df, Row("b", 1, 2))
@@ -2368,7 +2471,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
assert(filterCondition.isDefined)
val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: AvroScan, _, _) => f
+ case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@@ -2408,7 +2511,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
val basePath = dir.getCanonicalPath + "/avro"
val expected_plan_fragment =
s"""
- |\\(1\\) BatchScan
+ |\\(1\\) BatchScan avro file:$basePath
|Output \\[2\\]: \\[value#xL, id#x\\]
|DataFilters: \\[isnotnull\\(value#xL\\), \\(value#xL > 2\\)\\]
|Format: avro
@@ -2449,7 +2552,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
.where("value = 'a'")
val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: AvroScan, _, _) => f
+ case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
if (filtersPushdown) {
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala
similarity index 89%
rename from external/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala
index cdfa1b118b18d..40ed487087c8a 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala
@@ -34,9 +34,9 @@ class DeprecatedAvroFunctionsSuite extends QueryTest with SharedSparkSession {
import testImplicits._
test("roundtrip in to_avro and from_avro - int and string") {
- val df = spark.range(10).select('id, 'id.cast("string").as("str"))
+ val df = spark.range(10).select($"id", $"id".cast("string").as("str"))
- val avroDF = df.select(to_avro('id).as("a"), to_avro('str).as("b"))
+ val avroDF = df.select(to_avro($"id").as("a"), to_avro($"str").as("b"))
val avroTypeLong = s"""
|{
| "type": "int",
@@ -49,12 +49,12 @@ class DeprecatedAvroFunctionsSuite extends QueryTest with SharedSparkSession {
| "name": "str"
|}
""".stripMargin
- checkAnswer(avroDF.select(from_avro('a, avroTypeLong), from_avro('b, avroTypeStr)), df)
+ checkAnswer(avroDF.select(from_avro($"a", avroTypeLong), from_avro($"b", avroTypeStr)), df)
}
test("roundtrip in to_avro and from_avro - struct") {
- val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct"))
- val avroStructDF = df.select(to_avro('struct).as("avro"))
+ val df = spark.range(10).select(struct($"id", $"id".cast("string").as("str")).as("struct"))
+ val avroStructDF = df.select(to_avro($"struct").as("avro"))
val avroTypeStruct = s"""
|{
| "type": "record",
@@ -65,7 +65,7 @@ class DeprecatedAvroFunctionsSuite extends QueryTest with SharedSparkSession {
| ]
|}
""".stripMargin
- checkAnswer(avroStructDF.select(from_avro('avro, avroTypeStruct)), df)
+ checkAnswer(avroStructDF.select(from_avro($"avro", avroTypeStruct)), df)
}
test("roundtrip in to_avro and from_avro - array with null") {
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala b/connector/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala
similarity index 99%
rename from external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala
index 7368543642b99..aa0d713bbfb77 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroReadBenchmark.scala
@@ -33,8 +33,8 @@ import org.apache.spark.sql.types._
* To run this benchmark:
* 1. without sbt: bin/spark-submit --class
* --jars ,,,
- * 2. build/sbt "avro/test:runMain "
- * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "avro/test:runMain "
+ * 2. build/sbt "avro/Test/runMain "
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "avro/Test/runMain "
* Results will be written to "benchmarks/AvroReadBenchmark-results.txt".
* }}}
*/
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala b/connector/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala
similarity index 96%
rename from external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala
index 7f9febb5b14e5..d1db290f34b3b 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/execution/benchmark/AvroWriteBenchmark.scala
@@ -30,8 +30,8 @@ import org.apache.spark.storage.StorageLevel
* --jars ,,
* ,
*
- * 2. build/sbt "sql/test:runMain "
- * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "avro/test:runMain "
+ * 2. build/sbt "avro/Test/runMain "
+ * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "avro/Test/runMain "
* Results will be written to "benchmarks/AvroWriteBenchmark-results.txt".
* }}}
*/
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/execution/datasources/AvroReadSchemaSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/execution/datasources/AvroReadSchemaSuite.scala
similarity index 100%
rename from external/avro/src/test/scala/org/apache/spark/sql/execution/datasources/AvroReadSchemaSuite.scala
rename to connector/avro/src/test/scala/org/apache/spark/sql/execution/datasources/AvroReadSchemaSuite.scala
diff --git a/connector/connect/README.md b/connector/connect/README.md
new file mode 100644
index 0000000000000..dfe49cea3df1f
--- /dev/null
+++ b/connector/connect/README.md
@@ -0,0 +1,60 @@
+# Spark Connect
+
+This module contains the implementation of Spark Connect, a logical plan
+facade for the implementation in Spark. Spark Connect is directly integrated into the build
+of Spark.
+
+The documentation linked here is intended specifically for developers of Spark Connect,
+not as end-user documentation.
+
+## Development Topics
+
+### Guidelines for new clients
+
+When contributing a new client, please be aware that we strive for a common
+user experience across all languages. Please follow the guidelines below:
+
+* [Connection string configuration](docs/client-connection-string.md)
+* [Adding new messages](docs/adding-proto-messages.md) in the Spark Connect protocol.
+
+### Python client development
+
+Python-specific development guidelines are located in [python/docs/source/development/testing.rst](https://github.com/apache/spark/blob/master/python/docs/source/development/testing.rst), which is published under the [Development tab](https://spark.apache.org/docs/latest/api/python/development/index.html) in the PySpark documentation.
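+
+As a quick smoke test during client development, a Python client can connect to a locally
+running Spark Connect server (started, for example, with `connector/connect/bin/spark-connect`).
+The snippet below is a minimal sketch; it assumes PySpark with the Connect dependencies is
+installed and that the server is listening on the default port 15002:
+
+```python
+from pyspark.sql import SparkSession
+
+# Connect to a local Spark Connect server; adjust the host and port if the
+# server was started with a non-default configuration.
+spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
+spark.range(5).show()
+```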
+
+### Build with user-defined `protoc` and `protoc-gen-grpc-java`
+
+When the official `protoc` and `protoc-gen-grpc-java` binaries cannot be used to build the `connect` module in the compilation environment,
+for example when compiling the `connect` module on CentOS 6 or CentOS 7, whose default `glibc` version is older than 2.14, you can compile and test by
+specifying user-defined `protoc` and `protoc-gen-grpc-java` binaries as follows:
+
+```bash
+export SPARK_PROTOC_EXEC_PATH=/path-to-protoc-exe
+export CONNECT_PLUGIN_EXEC_PATH=/path-to-protoc-gen-grpc-java-exe
+./build/mvn -Phive -Puser-defined-protoc clean package
+```
+
+or
+
+```bash
+export SPARK_PROTOC_EXEC_PATH=/path-to-protoc-exe
+export CONNECT_PLUGIN_EXEC_PATH=/path-to-protoc-gen-grpc-java-exe
+./build/sbt -Puser-defined-protoc clean package
+```
+
+The user-defined `protoc` and `protoc-gen-grpc-java` binaries can be produced in the user's compilation environment by building from source;
+for the compilation steps, please refer to [protobuf](https://github.com/protocolbuffers/protobuf) and [grpc-java](https://github.com/grpc/grpc-java).
+
diff --git a/connector/connect/bin/spark-connect b/connector/connect/bin/spark-connect
new file mode 100755
index 0000000000000..772a88a04f3eb
--- /dev/null
+++ b/connector/connect/bin/spark-connect
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Start the spark-connect server with its logs printed to standard output. The script rebuilds the
+# server dependencies and starts the server at the default port. This can be used to debug the client
+# during client development.
+
+# Go to the Spark project root directory
+FWDIR="$(cd "`dirname "$0"`"/../../..; pwd)"
+cd "$FWDIR"
+export SPARK_HOME=$FWDIR
+
+# Determine the Scala version used in Spark
+SCALA_BINARY_VER=`grep "scala.binary.version" "${SPARK_HOME}/pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'`
+SCALA_ARG="-Pscala-${SCALA_BINARY_VER}"
+
+# Build the jars needed for spark submit and spark connect
+build/sbt "${SCALA_ARG}" -Phive -Pconnect package
+
+# This jar is already on the classpath, but the submit command wants a jar as the input.
+CONNECT_JAR=`ls "${SPARK_HOME}"/assembly/target/scala-"${SCALA_BINARY_VER}"/jars/spark-connect_*.jar | paste -sd ',' -`
+
+exec "${SPARK_HOME}"/bin/spark-submit "$@" --class org.apache.spark.sql.connect.SimpleSparkConnectService "$CONNECT_JAR"
diff --git a/connector/connect/bin/spark-connect-scala-client b/connector/connect/bin/spark-connect-scala-client
new file mode 100755
index 0000000000000..e7a15c56d7c4d
--- /dev/null
+++ b/connector/connect/bin/spark-connect-scala-client
@@ -0,0 +1,48 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Use the spark connect JVM client to connect to a spark connect server.
+#
+# Start a local server:
+# A local spark-connect server with default settings can be started using the following command:
+# `connector/connect/bin/spark-connect`
+# The client should be able to connect to this server directly with the default client settings.
+#
+# Connect to a remote server:
+# To connect to a remote server, use env var `SPARK_REMOTE` to configure the client connection
+# string. e.g.
+# `export SPARK_REMOTE="sc://:/;token=;="`
+
+# Go to the Spark project root directory
+FWDIR="$(cd "`dirname "$0"`"/../../..; pwd)"
+cd "$FWDIR"
+export SPARK_HOME=$FWDIR
+
+# Determine the Scala version used in Spark
+SCALA_BINARY_VER=`grep "scala.binary.version" "${SPARK_HOME}/pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'`
+SCALA_VER=`grep "scala.version" "${SPARK_HOME}/pom.xml" | grep ${SCALA_BINARY_VER} | head -n1 | awk -F '[<>]' '{print $3}'`
+SCALA_ARG="-Pscala-${SCALA_BINARY_VER}"
+
+# Build the jars needed for spark connect JVM client
+build/sbt "${SCALA_ARG}" "sql/package;connect-client-jvm/assembly"
+
+CONNECT_CLASSPATH="$(build/sbt "${SCALA_ARG}" -DcopyDependencies=false "export connect-client-jvm/fullClasspath" | grep jar | tail -n1)"
+SQL_CLASSPATH="$(build/sbt "${SCALA_ARG}" -DcopyDependencies=false "export sql/fullClasspath" | grep jar | tail -n1)"
+
+exec java -cp "$CONNECT_CLASSPATH:$SQL_CLASSPATH" org.apache.spark.sql.application.ConnectRepl "$@"
\ No newline at end of file
diff --git a/connector/connect/bin/spark-connect-shell b/connector/connect/bin/spark-connect-shell
new file mode 100755
index 0000000000000..0fcf831e03db1
--- /dev/null
+++ b/connector/connect/bin/spark-connect-shell
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# The spark connect shell for development. This shell script builds the spark connect server with
+# all dependencies and starts the server at the default port.
+# Use `/bin/spark-connect-shell` instead if rebuilding the dependency jars is not needed.
+
+# Go to the Spark project root directory
+FWDIR="$(cd "`dirname "$0"`"/../../..; pwd)"
+cd "$FWDIR"
+export SPARK_HOME=$FWDIR
+
+# Determine the Scala version used in Spark
+SCALA_BINARY_VER=`grep "scala.binary.version" "${SPARK_HOME}/pom.xml" | head -n1 | awk -F '[<>]' '{print $3}'`
+SCALA_ARG="-Pscala-${SCALA_BINARY_VER}"
+
+# Build the jars needed for spark submit and spark connect
+build/sbt "${SCALA_ARG}" -Phive -Pconnect package
+
+exec "${SPARK_HOME}"/bin/spark-shell --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin "$@"
diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml
new file mode 100644
index 0000000000000..f16761d3a6ae2
--- /dev/null
+++ b/connector/connect/client/jvm/pom.xml
@@ -0,0 +1,227 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.12</artifactId>
+    <version>3.4.1</version>
+    <relativePath>../../../../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>spark-connect-client-jvm_2.12</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Project Connect Client</name>
+  <url>https://spark.apache.org/</url>
+  <properties>
+    <sbt.project.name>connect-client-jvm</sbt.project.name>
+    <guava.version>31.0.1-jre</guava.version>
+    <guava.failureaccess.version>1.0.1</guava.failureaccess.version>
+    <mima.version>1.1.0</mima.version>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-connect-common_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>${protobuf.version}</version>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+      <version>${guava.version}</version>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>failureaccess</artifactId>
+      <version>${guava.failureaccess.version}</version>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>io.netty</groupId>
+      <artifactId>netty-codec-http2</artifactId>
+      <version>${netty.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.netty</groupId>
+      <artifactId>netty-handler-proxy</artifactId>
+      <version>${netty.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>io.netty</groupId>
+      <artifactId>netty-transport-native-unix-common</artifactId>
+      <version>${netty.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>com.lihaoyi</groupId>
+      <artifactId>ammonite_${scala.version}</artifactId>
+      <version>${ammonite.version}</version>
+      <scope>provided</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>org.scala-lang.modules</groupId>
+          <artifactId>scala-xml_${scala.binary.version}</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-connect-common_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>com.google.guava</groupId>
+          <artifactId>guava</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.scalacheck</groupId>
+      <artifactId>scalacheck_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.typesafe</groupId>
+      <artifactId>mima-core_${scala.binary.version}</artifactId>
+      <version>${mima.version}</version>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-shade-plugin</artifactId>
+        <configuration>
+          <shadedArtifactAttached>false</shadedArtifactAttached>
+          <artifactSet>
+            <includes>
+              <include>com.google.android:*</include>
+              <include>com.google.api.grpc:*</include>
+              <include>com.google.code.findbugs:*</include>
+              <include>com.google.code.gson:*</include>
+              <include>com.google.errorprone:*</include>
+              <include>com.google.guava:*</include>
+              <include>com.google.j2objc:*</include>
+              <include>com.google.protobuf:*</include>
+              <include>io.grpc:*</include>
+              <include>io.netty:*</include>
+              <include>io.perfmark:*</include>
+              <include>org.codehaus.mojo:*</include>
+              <include>org.checkerframework:*</include>
+              <include>org.apache.spark:spark-connect-common_${scala.binary.version}</include>
+            </includes>
+          </artifactSet>
+          <relocations>
+            <relocation>
+              <pattern>io.grpc</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.client.io.grpc</shadedPattern>
+              <includes>
+                <include>io.grpc.**</include>
+              </includes>
+            </relocation>
+            <relocation>
+              <pattern>com.google</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.client.com.google</shadedPattern>
+            </relocation>
+            <relocation>
+              <pattern>io.netty</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.client.io.netty</shadedPattern>
+            </relocation>
+            <relocation>
+              <pattern>org.checkerframework</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.client.org.checkerframework</shadedPattern>
+            </relocation>
+            <relocation>
+              <pattern>javax.annotation</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.client.javax.annotation</shadedPattern>
+            </relocation>
+            <relocation>
+              <pattern>io.perfmark</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.client.io.perfmark</shadedPattern>
+            </relocation>
+            <relocation>
+              <pattern>org.codehaus</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.client.org.codehaus</shadedPattern>
+            </relocation>
+            <relocation>
+              <pattern>android.annotation</pattern>
+              <shadedPattern>${spark.shade.packageName}.connect.client.android.annotation</shadedPattern>
+            </relocation>
+          </relocations>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-jar-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>prepare-test-jar</id>
+            <phase>test-compile</phase>
+            <goals>
+              <goal>test-jar</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
+</project>
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java
new file mode 100644
index 0000000000000..95af157687c85
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql;
+
+import org.apache.spark.annotation.Stable;
+
+/**
+ * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source.
+ *
+ * @since 3.4.0
+ */
+@Stable
+public enum SaveMode {
+ /**
+ * Append mode means that when saving a DataFrame to a data source, if data/table already exists,
+ * contents of the DataFrame are expected to be appended to existing data.
+ *
+ * @since 3.4.0
+ */
+ Append,
+ /**
+ * Overwrite mode means that when saving a DataFrame to a data source,
+ * if data/table already exists, existing data is expected to be overwritten by the contents of
+ * the DataFrame.
+ *
+ * @since 3.4.0
+ */
+ Overwrite,
+ /**
+ * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists,
+ * an exception is expected to be thrown.
+ *
+ * @since 3.4.0
+ */
+ ErrorIfExists,
+ /**
+ * Ignore mode means that when saving a DataFrame to a data source, if data already exists,
+ * the save operation is expected to not save the contents of the DataFrame and to not
+ * change the existing data.
+ *
+ * @since 3.4.0
+ */
+ Ignore
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
new file mode 100644
index 0000000000000..6a660a7482e27
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Column.scala
@@ -0,0 +1,1478 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering
+import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.connect.common.DataTypeProtoConverter
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types._
+
+/**
+ * A column that will be computed based on the data in a `DataFrame`.
+ *
+ * A new column can be constructed based on the input columns present in a DataFrame:
+ *
+ * {{{
+ * df("columnName") // On a specific `df` DataFrame.
+ * col("columnName") // A generic column not yet associated with a DataFrame.
+ * col("columnName.field") // Extracting a struct field
+ * col("`a.column.with.dots`") // Escape `.` in column names.
+ * $"columnName" // Scala short hand for a named column.
+ * }}}
+ *
+ * [[Column]] objects can be composed to form complex expressions:
+ *
+ * {{{
+ * $"a" + 1
+ * }}}
+ *
+ * @since 3.4.0
+ */
+class Column private[sql] (@DeveloperApi val expr: proto.Expression) extends Logging {
+
+ private[sql] def this(name: String, planId: Option[Long]) =
+ this(Column.nameToExpression(name, planId))
+
+ private[sql] def this(name: String) =
+ this(name, None)
+
+ private def fn(name: String): Column = Column.fn(name, this)
+ private def fn(name: String, other: Column): Column = Column.fn(name, this, other)
+ private def fn(name: String, other: Any): Column = Column.fn(name, this, lit(other))
+
+ override def toString: String = expr.toString
+
+ override def equals(that: Any): Boolean = that match {
+ case that: Column => expr == that.expr
+ case _ => false
+ }
+
+ override def hashCode: Int = expr.hashCode()
+
+ /**
+ * Provides a type hint about the expected return value of this column. This information can be
+ * used by operations such as `select` on a [[Dataset]] to automatically convert the results
+ * into the correct JVM types.
+ * @since 3.4.0
+ */
+ def as[U: Encoder]: TypedColumn[Any, U] = {
+ val encoder = implicitly[Encoder[U]].asInstanceOf[AgnosticEncoder[U]]
+ new TypedColumn[Any, U](expr, encoder)
+ }
+
+ /**
+ * Extracts a value or values from a complex type. The following types of extraction are
+ * supported:
+ * - Given an Array, an integer ordinal can be used to retrieve a single value.
+ * - Given a Map, a key of the correct type can be used to retrieve an individual value.
+ * - Given a Struct, a string fieldName can be used to extract that field.
+ *   - Given an Array of Structs, a string fieldName can be used to extract the field of every
+ * struct in that array, and return an Array of fields.
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def apply(extraction: Any): Column = Column { builder =>
+ builder.getUnresolvedExtractValueBuilder
+ .setChild(expr)
+ .setExtraction(lit(extraction).expr)
+ }
+
+ /**
+ * Unary minus, i.e. negate the expression.
+ * {{{
+ * // Scala: select the amount column and negate all values.
+ * df.select( -df("amount") )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.functions.*;
+ * df.select( negate(col("amount")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def unary_- : Column = fn("negative")
+
+ /**
+ * Inversion of boolean expression, i.e. NOT.
+ * {{{
+ * // Scala: select rows that are not active (isActive === false)
+ * df.filter( !df("isActive") )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.functions.*;
+ * df.filter( not(df.col("isActive")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def unary_! : Column = fn("!")
+
+ /**
+ * Equality test.
+ * {{{
+ * // Scala:
+ * df.filter( df("colA") === df("colB") )
+ *
+ * // Java
+ * import static org.apache.spark.sql.functions.*;
+ * df.filter( col("colA").equalTo(col("colB")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def ===(other: Any): Column = fn("=", other)
+
+ /**
+ * Equality test.
+ * {{{
+ * // Scala:
+ * df.filter( df("colA") === df("colB") )
+ *
+ * // Java
+ * import static org.apache.spark.sql.functions.*;
+ * df.filter( col("colA").equalTo(col("colB")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def equalTo(other: Any): Column = this === other
+
+ /**
+ * Inequality test.
+ * {{{
+ * // Scala:
+ * df.select( df("colA") =!= df("colB") )
+ * df.select( !(df("colA") === df("colB")) )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.functions.*;
+ * df.filter( col("colA").notEqual(col("colB")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def =!=(other: Any): Column = !(this === other)
+
+ /**
+ * Inequality test.
+ * {{{
+ * // Scala:
+ * df.select( df("colA") !== df("colB") )
+ * df.select( !(df("colA") === df("colB")) )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.functions.*;
+ * df.filter( col("colA").notEqual(col("colB")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ @deprecated("!== does not have the same precedence as ===, use =!= instead", "2.0.0")
+ def !==(other: Any): Column = this =!= other
+
+ /**
+ * Inequality test.
+ * {{{
+ * // Scala:
+ * df.select( df("colA") !== df("colB") )
+ * df.select( !(df("colA") === df("colB")) )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.functions.*;
+ * df.filter( col("colA").notEqual(col("colB")) );
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def notEqual(other: Any): Column = this =!= other
+
+ /**
+ * Greater than.
+ * {{{
+ * // Scala: The following selects people older than 21.
+ * people.select( people("age") > 21 )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.functions.*;
+ * people.select( people.col("age").gt(21) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def >(other: Any): Column = fn(">", other)
+
+ /**
+ * Greater than.
+ * {{{
+ * // Scala: The following selects people older than 21.
+ * people.select( people("age") > lit(21) )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.functions.*;
+ * people.select( people.col("age").gt(21) );
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def gt(other: Any): Column = this > other
+
+ /**
+ * Less than.
+ * {{{
+ * // Scala: The following selects people younger than 21.
+ * people.select( people("age") < 21 )
+ *
+ * // Java:
+ * people.select( people.col("age").lt(21) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def <(other: Any): Column = fn("<", other)
+
+ /**
+ * Less than.
+ * {{{
+ * // Scala: The following selects people younger than 21.
+ * people.select( people("age") < 21 )
+ *
+ * // Java:
+ * people.select( people.col("age").lt(21) );
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def lt(other: Any): Column = this < other
+
+ /**
+ * Less than or equal to.
+ * {{{
+ * // Scala: The following selects people age 21 or younger than 21.
+ * people.select( people("age") <= 21 )
+ *
+ * // Java:
+ * people.select( people.col("age").leq(21) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def <=(other: Any): Column = fn("<=", other)
+
+ /**
+ * Less than or equal to.
+ * {{{
+ * // Scala: The following selects people age 21 or younger than 21.
+ * people.select( people("age") <= 21 )
+ *
+ * // Java:
+ * people.select( people.col("age").leq(21) );
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def leq(other: Any): Column = this <= other
+
+ /**
+ * Greater than or equal to an expression.
+ * {{{
+ * // Scala: The following selects people age 21 or older than 21.
+ * people.select( people("age") >= 21 )
+ *
+ * // Java:
+ * people.select( people.col("age").geq(21) )
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def >=(other: Any): Column = fn(">=", other)
+
+ /**
+ * Greater than or equal to an expression.
+ * {{{
+ * // Scala: The following selects people age 21 or older than 21.
+ * people.select( people("age") >= 21 )
+ *
+ * // Java:
+ * people.select( people.col("age").geq(21) )
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def geq(other: Any): Column = this >= other
+
+ /**
+ * Equality test that is safe for null values.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def <=>(other: Any): Column = fn("<=>", other)
+
+ /**
+ * Equality test that is safe for null values.
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def eqNullSafe(other: Any): Column = this <=> other
+
+ private def extractWhen(name: String): java.util.List[proto.Expression] = {
+ def fail(): Nothing = {
+ throw new IllegalArgumentException(
+ s"$name() can only be applied on a Column previously generated by when() function")
+ }
+ if (!expr.hasUnresolvedFunction) {
+ fail()
+ }
+ val parentFn = expr.getUnresolvedFunction
+ if (parentFn.getFunctionName != "when") {
+ fail()
+ }
+ parentFn.getArgumentsList
+ }
+
+ /**
+ * Evaluates a list of conditions and returns one of multiple possible result expressions. If
+ * otherwise is not defined at the end, null is returned for unmatched conditions.
+ *
+ * {{{
+ * // Example: encoding gender string column into integer.
+ *
+ * // Scala:
+ * people.select(when(people("gender") === "male", 0)
+ * .when(people("gender") === "female", 1)
+ * .otherwise(2))
+ *
+ * // Java:
+ * people.select(when(col("gender").equalTo("male"), 0)
+ * .when(col("gender").equalTo("female"), 1)
+ * .otherwise(2))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def when(condition: Column, value: Any): Column = {
+ val expressions = extractWhen("when")
+ if (expressions.size() % 2 == 1) {
+ throw new IllegalArgumentException("when() cannot be applied once otherwise() is applied")
+ }
+ Column { builder =>
+ builder.getUnresolvedFunctionBuilder
+ .setFunctionName("when")
+ .addAllArguments(expressions)
+ .addArguments(condition.expr)
+ .addArguments(lit(value).expr)
+ }
+ }
+
+ /**
+ * Evaluates a list of conditions and returns one of multiple possible result expressions. If
+ * otherwise is not defined at the end, null is returned for unmatched conditions.
+ *
+ * {{{
+ * // Example: encoding gender string column into integer.
+ *
+ * // Scala:
+ * people.select(when(people("gender") === "male", 0)
+ * .when(people("gender") === "female", 1)
+ * .otherwise(2))
+ *
+ * // Java:
+ * people.select(when(col("gender").equalTo("male"), 0)
+ * .when(col("gender").equalTo("female"), 1)
+ * .otherwise(2))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def otherwise(value: Any): Column = {
+ val expressions = extractWhen("otherwise")
+ if (expressions.size() % 2 == 1) {
+ throw new IllegalArgumentException(
+ "otherwise() can only be applied once on a Column previously generated by when()")
+ }
+ Column { builder =>
+ builder.getUnresolvedFunctionBuilder
+ .setFunctionName("when")
+ .addAllArguments(expressions)
+ .addArguments(lit(value).expr)
+ }
+ }
+
+ /**
+ * True if the current column is between the lower bound and upper bound, inclusive.
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def between(lowerBound: Any, upperBound: Any): Column = {
+ (this >= lowerBound) && (this <= upperBound)
+ }
+
+ /**
+ * True if the current expression is NaN.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def isNaN: Column = fn("isNaN")
+
+ /**
+ * True if the current expression is null.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def isNull: Column = fn("isNull")
+
+ /**
+ * True if the current expression is NOT null.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def isNotNull: Column = fn("isNotNull")
+
+ /**
+ * Boolean OR.
+ * {{{
+ * // Scala: The following selects people that are in school or employed.
+ * people.filter( people("inSchool") || people("isEmployed") )
+ *
+ * // Java:
+ * people.filter( people.col("inSchool").or(people.col("isEmployed")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def ||(other: Any): Column = fn("or", other)
+
+ /**
+ * Boolean OR.
+ * {{{
+ * // Scala: The following selects people that are in school or employed.
+ * people.filter( people("inSchool") || people("isEmployed") )
+ *
+ * // Java:
+ * people.filter( people.col("inSchool").or(people.col("isEmployed")) );
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def or(other: Column): Column = this || other
+
+ /**
+ * Boolean AND.
+ * {{{
+ * // Scala: The following selects people that are in school and employed at the same time.
+ * people.select( people("inSchool") && people("isEmployed") )
+ *
+ * // Java:
+ * people.select( people.col("inSchool").and(people.col("isEmployed")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def &&(other: Any): Column = fn("and", other)
+
+ /**
+ * Boolean AND.
+ * {{{
+ * // Scala: The following selects people that are in school and employed at the same time.
+ * people.select( people("inSchool") && people("isEmployed") )
+ *
+ * // Java:
+ * people.select( people.col("inSchool").and(people.col("isEmployed")) );
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def and(other: Column): Column = this && other
+
+ /**
+ * Sum of this expression and another expression.
+ * {{{
+ * // Scala: The following selects the sum of a person's height and weight.
+ * people.select( people("height") + people("weight") )
+ *
+ * // Java:
+ * people.select( people.col("height").plus(people.col("weight")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def +(other: Any): Column = fn("+", other)
+
+ /**
+ * Sum of this expression and another expression.
+ * {{{
+ * // Scala: The following selects the sum of a person's height and weight.
+ * people.select( people("height") + people("weight") )
+ *
+ * // Java:
+ * people.select( people.col("height").plus(people.col("weight")) );
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def plus(other: Any): Column = this + other
+
+ /**
+ * Subtraction. Subtract the other expression from this expression.
+ * {{{
+ * // Scala: The following selects the difference between people's height and their weight.
+ * people.select( people("height") - people("weight") )
+ *
+ * // Java:
+ * people.select( people.col("height").minus(people.col("weight")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def -(other: Any): Column = fn("-", other)
+
+ /**
+ * Subtraction. Subtract the other expression from this expression.
+ * {{{
+ * // Scala: The following selects the difference between people's height and their weight.
+ * people.select( people("height") - people("weight") )
+ *
+ * // Java:
+ * people.select( people.col("height").minus(people.col("weight")) );
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def minus(other: Any): Column = this - other
+
+ /**
+ * Multiplication of this expression and another expression.
+ * {{{
+ * // Scala: The following multiplies a person's height by their weight.
+ * people.select( people("height") * people("weight") )
+ *
+ * // Java:
+ * people.select( people.col("height").multiply(people.col("weight")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def *(other: Any): Column = fn("*", other)
+
+ /**
+ * Multiplication of this expression and another expression.
+ * {{{
+ * // Scala: The following multiplies a person's height by their weight.
+ * people.select( people("height") * people("weight") )
+ *
+ * // Java:
+ * people.select( people.col("height").multiply(people.col("weight")) );
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def multiply(other: Any): Column = this * other
+
+ /**
+ * Division of this expression by another expression.
+ * {{{
+ * // Scala: The following divides a person's height by their weight.
+ * people.select( people("height") / people("weight") )
+ *
+ * // Java:
+ * people.select( people.col("height").divide(people.col("weight")) );
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def /(other: Any): Column = fn("/", other)
+
+ /**
+ * Division of this expression by another expression.
+ * {{{
+ * // Scala: The following divides a person's height by their weight.
+ * people.select( people("height") / people("weight") )
+ *
+ * // Java:
+ * people.select( people.col("height").divide(people.col("weight")) );
+ * }}}
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def divide(other: Any): Column = this / other
+
+ /**
+ * Modulo (a.k.a. remainder) expression.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def %(other: Any): Column = fn("%", other)
+
+ /**
+ * Modulo (a.k.a. remainder) expression.
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def mod(other: Any): Column = this % other
+
+ /**
+ * A boolean expression that is evaluated to true if the value of this expression is contained
+ * by the evaluated values of the arguments.
+ *
+ * Note: Since the type of the elements in the list is inferred only at run time, the
+ * elements will be "up-casted" to the most common type for comparison. For eg: 1) In the case
+ * of "Int vs String", the "Int" will be up-casted to "String" and the comparison will look like
+ * "String vs String". 2) In the case of "Float vs Double", the "Float" will be up-casted to
+ * "Double" and the comparison will look like "Double vs Double"
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def isin(list: Any*): Column = Column.fn("in", this +: list.map(lit): _*)
+
+ /**
+ * A boolean expression that is evaluated to true if the value of this expression is contained
+ * by the provided collection.
+ *
+ * Note: Since the type of the elements in the collection is inferred only at run time,
+ * the elements will be "up-casted" to the most common type for comparison. For eg: 1) In the
+ * case of "Int vs String", the "Int" will be up-casted to "String" and the comparison will look
+ * like "String vs String". 2) In the case of "Float vs Double", the "Float" will be up-casted
+ * to "Double" and the comparison will look like "Double vs Double"
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*)
+
+ /**
+ * A boolean expression that is evaluated to true if the value of this expression is contained
+ * by the provided collection.
+ *
+ * Note: Since the type of the elements in the collection is inferred only at run time,
+ * the elements will be "up-casted" to the most common type for comparison. For eg: 1) In the
+ * case of "Int vs String", the "Int" will be up-casted to "String" and the comparison will look
+ * like "String vs String". 2) In the case of "Float vs Double", the "Float" will be up-casted
+ * to "Double" and the comparison will look like "Double vs Double"
+ *
+ * @group java_expr_ops
+ * @since 3.4.0
+ */
+ def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala)
+
+ /**
+ * SQL like expression. Returns a boolean column based on a SQL LIKE match.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def like(literal: String): Column = fn("like", literal)
+
+ /**
+ * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex match.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def rlike(literal: String): Column = fn("rlike", literal)
+
+ /**
+ * SQL ILIKE expression (case insensitive LIKE).
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def ilike(literal: String): Column = fn("ilike", literal)
+
+ /**
+ * An expression that gets an item at position `ordinal` out of an array, or gets a value by key
+ * `key` in a `MapType`.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def getItem(key: Any): Column = apply(key)
+
+ // scalastyle:off line.size.limit
+ /**
+ * An expression that adds/replaces field in `StructType` by name.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".withField("c", lit(3)))
+ * // result: {"a":1,"b":2,"c":3}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".withField("b", lit(3)))
+ * // result: {"a":1,"b":3}
+ *
+ * val df = sql("SELECT CAST(NULL AS struct) struct_col")
+ * df.select($"struct_col".withField("c", lit(3)))
+ * // result: null of type struct
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+ * df.select($"struct_col".withField("b", lit(100)))
+ * // result: {"a":1,"b":100,"b":100}
+ *
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a.c", lit(3)))
+ * // result: {"a":{"a":1,"b":2,"c":3}}
+ *
+ * val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+ * df.select($"struct_col".withField("a.c", lit(3)))
+ * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+ * }}}
+ *
+ * This method supports adding/replacing nested fields directly e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a.c", lit(3)).withField("a.d", lit(4)))
+ * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+ * }}}
+ *
+ * However, if you are going to add/replace multiple nested fields, it is more optimal to
+ * extract out the nested struct before adding/replacing multiple fields e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a", $"struct_col.a".withField("c", lit(3)).withField("d", lit(4))))
+ * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ // scalastyle:on line.size.limit
+ def withField(fieldName: String, col: Column): Column = {
+ require(fieldName != null, "fieldName cannot be null")
+ require(col != null, "col cannot be null")
+ Column { builder =>
+ builder.getUpdateFieldsBuilder
+ .setStructExpression(expr)
+ .setFieldName(fieldName)
+ .setValueExpression(col.expr)
+ }
+ }
+
+ // scalastyle:off line.size.limit
+ /**
+ * An expression that drops fields in `StructType` by name. This is a no-op if schema doesn't
+ * contain field name(s).
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".dropFields("b"))
+ * // result: {"a":1}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".dropFields("c"))
+ * // result: {"a":1,"b":2}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+ * df.select($"struct_col".dropFields("b", "c"))
+ * // result: {"a":1}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".dropFields("a", "b"))
+ * // result: org.apache.spark.sql.AnalysisException: [DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS] Cannot resolve "update_fields(struct_col, dropfield(), dropfield())" due to data type mismatch: Cannot drop all fields in struct.;
+ *
+ * val df = sql("SELECT CAST(NULL AS struct) struct_col")
+ * df.select($"struct_col".dropFields("b"))
+ * // result: null of type struct
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+ * df.select($"struct_col".dropFields("b"))
+ * // result: {"a":1}
+ *
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".dropFields("a.b"))
+ * // result: {"a":{"a":1}}
+ *
+ * val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+ * df.select($"struct_col".dropFields("a.c"))
+ * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+ * }}}
+ *
+ * This method supports dropping multiple nested fields directly e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".dropFields("a.b", "a.c"))
+ * // result: {"a":{"a":1}}
+ * }}}
+ *
+ * However, if you are going to drop multiple nested fields, it is more optimal to extract out
+ * the nested struct before dropping multiple fields from it e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a", $"struct_col.a".dropFields("b", "c")))
+ * // result: {"a":{"a":1}}
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ // scalastyle:on line.size.limit
+ def dropFields(fieldNames: String*): Column = {
+ fieldNames.foldLeft(this) { case (column, fieldName) =>
+ Column { builder =>
+ builder.getUpdateFieldsBuilder
+ .setStructExpression(column.expr)
+ .setFieldName(fieldName)
+ }
+ }
+ }
+
+ /**
+ * An expression that gets a field by name in a `StructType`.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def getField(fieldName: String): Column = apply(fieldName)
+
+ /**
+ * An expression that returns a substring.
+ * @param startPos
+ * expression for the starting position.
+ * @param len
+ * expression for the length of the substring.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def substr(startPos: Column, len: Column): Column = Column.fn("substr", this, startPos, len)
+
+ /**
+ * An expression that returns a substring.
+ * @param startPos
+ * starting position.
+ * @param len
+ * length of the substring.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def substr(startPos: Int, len: Int): Column = substr(lit(startPos), lit(len))
+
+ /**
+ * Contains the other element. Returns a boolean column based on a string match.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def contains(other: Any): Column = fn("contains", other)
+
+ /**
+ * String starts with. Returns a boolean column based on a string match.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def startsWith(other: Column): Column = fn("startswith", other)
+
+ /**
+ * String starts with another string literal. Returns a boolean column based on a string match.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def startsWith(literal: String): Column = startsWith(lit(literal))
+
+ /**
+ * String ends with. Returns a boolean column based on a string match.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def endsWith(other: Column): Column = fn("endswith", other)
+
+ /**
+ * String ends with another string literal. Returns a boolean column based on a string match.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def endsWith(literal: String): Column = endsWith(lit(literal))
+
+ /**
+ * Gives the column an alias. Same as `as`.
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select($"colA".alias("colB"))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def alias(alias: String): Column = name(alias)
+
+ /**
+ * Gives the column an alias.
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select($"colA".as("colB"))
+ * }}}
+ *
+ * If the current column has metadata associated with it, this metadata will be propagated to
+ * the new column. If this is not desired, use the API `as(alias: String, metadata: Metadata)` with
+ * explicit metadata.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def as(alias: String): Column = name(alias)
+
+ /**
+ * (Scala-specific) Assigns the given aliases to the results of a table generating function.
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select(explode($"myMap").as("key" :: "value" :: Nil))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def as(aliases: Seq[String]): Column = Column { builder =>
+ builder.getAliasBuilder.setExpr(expr).addAllName(aliases.asJava)
+ }
+
+ /**
+ * Assigns the given aliases to the results of a table generating function.
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select(explode($"myMap").as("key" :: "value" :: Nil))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def as(aliases: Array[String]): Column = as(aliases.toSeq)
+
+ /**
+ * Gives the column an alias.
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select($"colA".as("colB"))
+ * }}}
+ *
+ * If the current column has metadata associated with it, this metadata will be propagated to
+ * the new column. If this is not desired, use the API `as(alias: String, metadata: Metadata)` with
+ * explicit metadata.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def as(alias: Symbol): Column = name(alias.name)
+
+ /**
+ * Gives the column an alias with metadata.
+ * {{{
+ * val metadata: Metadata = ...
+ * df.select($"colA".as("colB", metadata))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def as(alias: String, metadata: Metadata): Column = Column { builder =>
+ builder.getAliasBuilder
+ .setExpr(expr)
+ .addName(alias)
+ .setMetadata(metadata.json)
+ }
+
+ /**
+ * Gives the column a name (alias).
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select($"colA".name("colB"))
+ * }}}
+ *
+ * If the current column has metadata associated with it, this metadata will be propagated to
+ * the new column. If this is not desired, use the API `as(alias: String, metadata: Metadata)` with
+ * explicit metadata.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def name(alias: String): Column = as(alias :: Nil)
+
+ /**
+ * Casts the column to a different data type.
+ * {{{
+ * // Casts colA to IntegerType.
+ * import org.apache.spark.sql.types.IntegerType
+ * df.select(df("colA").cast(IntegerType))
+ *
+ * // equivalent to
+ * df.select(df("colA").cast("int"))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def cast(to: DataType): Column = Column { builder =>
+ builder.getCastBuilder
+ .setExpr(expr)
+ .setType(DataTypeProtoConverter.toConnectProtoType(to))
+ }
+
+ /**
+ * Casts the column to a different data type, using the canonical string representation of the
+ * type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`, `float`,
+ * `double`, `decimal`, `date`, `timestamp`.
+ * {{{
+ * // Casts colA to integer.
+ * df.select(df("colA").cast("int"))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def cast(to: String): Column = cast(CatalystSqlParser.parseDataType(to))
+
+ /**
+ * Returns a sort expression based on the descending order of the column.
+ * {{{
+ * // Scala
+ * df.sort(df("age").desc)
+ *
+ * // Java
+ * df.sort(df.col("age").desc());
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def desc: Column = desc_nulls_last
+
+ /**
+ * Returns a sort expression based on the descending order of the column, and null values appear
+ * before non-null values.
+ * {{{
+ * // Scala: sort a DataFrame by age column in descending order and null values appearing first.
+ * df.sort(df("age").desc_nulls_first)
+ *
+ * // Java
+ * df.sort(df.col("age").desc_nulls_first());
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def desc_nulls_first: Column =
+ buildSortOrder(SortDirection.SORT_DIRECTION_DESCENDING, NullOrdering.SORT_NULLS_FIRST)
+
+ /**
+ * Returns a sort expression based on the descending order of the column, and null values appear
+ * after non-null values.
+ * {{{
+ * // Scala: sort a DataFrame by age column in descending order and null values appearing last.
+ * df.sort(df("age").desc_nulls_last)
+ *
+ * // Java
+ * df.sort(df.col("age").desc_nulls_last());
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def desc_nulls_last: Column =
+ buildSortOrder(SortDirection.SORT_DIRECTION_DESCENDING, NullOrdering.SORT_NULLS_LAST)
+
+ /**
+ * Returns a sort expression based on ascending order of the column.
+ * {{{
+ * // Scala: sort a DataFrame by age column in ascending order.
+ * df.sort(df("age").asc)
+ *
+ * // Java
+ * df.sort(df.col("age").asc());
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def asc: Column = asc_nulls_first
+
+ /**
+ * Returns a sort expression based on ascending order of the column, and null values appear
+ * before non-null values.
+ * {{{
+ * // Scala: sort a DataFrame by age column in ascending order and null values appearing first.
+ * df.sort(df("age").asc_nulls_first)
+ *
+ * // Java
+ * df.sort(df.col("age").asc_nulls_first());
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def asc_nulls_first: Column =
+ buildSortOrder(SortDirection.SORT_DIRECTION_ASCENDING, NullOrdering.SORT_NULLS_FIRST)
+
+ /**
+ * Returns a sort expression based on ascending order of the column, and null values appear
+ * after non-null values.
+ * {{{
+ * // Scala: sort a DataFrame by age column in ascending order and null values appearing last.
+ * df.sort(df("age").asc_nulls_last)
+ *
+ * // Java
+ * df.sort(df.col("age").asc_nulls_last());
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def asc_nulls_last: Column =
+ buildSortOrder(SortDirection.SORT_DIRECTION_ASCENDING, NullOrdering.SORT_NULLS_LAST)
+
+ private def buildSortOrder(sortDirection: SortDirection, nullOrdering: NullOrdering): Column = {
+ Column { builder =>
+ builder.getSortOrderBuilder
+ .setChild(expr)
+ .setDirection(sortDirection)
+ .setNullOrdering(nullOrdering)
+ }
+ }
+
+ private[sql] def sortOrder: proto.Expression.SortOrder = {
+ val base = if (expr.hasSortOrder) {
+ expr
+ } else {
+ asc.expr
+ }
+ base.getSortOrder
+ }
+
+ /**
+ * Prints the expression to the console for debugging purposes.
+ *
+ * @group df_ops
+ * @since 3.4.0
+ */
+ def explain(extended: Boolean): Unit = {
+ // scalastyle:off println
+ if (extended) {
+ println(expr)
+ } else {
+ println(toString)
+ }
+ // scalastyle:on println
+ }
+
+ /**
+ * Compute bitwise OR of this expression with another expression.
+ * {{{
+ * df.select($"colA".bitwiseOR($"colB"))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def bitwiseOR(other: Any): Column = fn("|", other)
+
+ /**
+ * Compute bitwise AND of this expression with another expression.
+ * {{{
+ * df.select($"colA".bitwiseAND($"colB"))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def bitwiseAND(other: Any): Column = fn("&", other)
+
+ /**
+ * Compute bitwise XOR of this expression with another expression.
+ * {{{
+ * df.select($"colA".bitwiseXOR($"colB"))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def bitwiseXOR(other: Any): Column = fn("^", other)
+
+ /**
+ * Defines a windowing column.
+ *
+ * {{{
+ * val w = Window.partitionBy("name").orderBy("id")
+ * df.select(
+ * sum("price").over(w.rangeBetween(Window.unboundedPreceding, 2)),
+ * avg("price").over(w.rowsBetween(Window.currentRow, 4))
+ * )
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def over(window: expressions.WindowSpec): Column = window.withAggregate(this)
+
+ /**
+ * Defines an empty analytic clause. In this case the analytic function is applied and presented
+ * for all rows in the result set.
+ *
+ * {{{
+ * df.select(
+ * sum("price").over(),
+ * avg("price").over()
+ * )
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ def over(): Column = over(Window.spec)
+}
+
+private[sql] object Column {
+
+ def apply(name: String): Column = new Column(name)
+
+ def apply(name: String, planId: Option[Long]): Column = new Column(name, planId)
+
+ def nameToExpression(name: String, planId: Option[Long] = None): proto.Expression = {
+ val builder = proto.Expression.newBuilder()
+ name match {
+ case "*" =>
+ builder.getUnresolvedStarBuilder
+ case _ if name.endsWith(".*") =>
+ builder.getUnresolvedStarBuilder.setUnparsedTarget(name)
+ case _ =>
+ val attributeBuilder = builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name)
+ planId.foreach(attributeBuilder.setPlanId)
+ }
+ builder.build()
+ }
+
+ private[sql] def apply(f: proto.Expression.Builder => Unit): Column = {
+ val builder = proto.Expression.newBuilder()
+ f(builder)
+ new Column(builder.build())
+ }
+
+ @DeveloperApi
+ def apply(extension: com.google.protobuf.Any): Column = {
+ apply(_.setExtension(extension))
+ }
+
+ private[sql] def fn(name: String, inputs: Column*): Column = {
+ fn(name, isDistinct = false, inputs: _*)
+ }
+
+ private[sql] def fn(name: String, isDistinct: Boolean, inputs: Column*): Column = Column {
+ builder =>
+ builder.getUnresolvedFunctionBuilder
+ .setFunctionName(name)
+ .setIsDistinct(isDistinct)
+ .addAllArguments(inputs.map(_.expr).asJava)
+ }
+}
+
+/**
+ * A convenient class used for constructing schema.
+ *
+ * @since 3.4.0
+ */
+class ColumnName(name: String) extends Column(name) {
+
+ /**
+ * Creates a new `StructField` of type boolean.
+ * @since 3.4.0
+ */
+ def boolean: StructField = StructField(name, BooleanType)
+
+ /**
+ * Creates a new `StructField` of type byte.
+ * @since 3.4.0
+ */
+ def byte: StructField = StructField(name, ByteType)
+
+ /**
+ * Creates a new `StructField` of type short.
+ * @since 3.4.0
+ */
+ def short: StructField = StructField(name, ShortType)
+
+ /**
+ * Creates a new `StructField` of type int.
+ * @since 3.4.0
+ */
+ def int: StructField = StructField(name, IntegerType)
+
+ /**
+ * Creates a new `StructField` of type long.
+ * @since 3.4.0
+ */
+ def long: StructField = StructField(name, LongType)
+
+ /**
+ * Creates a new `StructField` of type float.
+ * @since 3.4.0
+ */
+ def float: StructField = StructField(name, FloatType)
+
+ /**
+ * Creates a new `StructField` of type double.
+ * @since 3.4.0
+ */
+ def double: StructField = StructField(name, DoubleType)
+
+ /**
+ * Creates a new `StructField` of type string.
+ * @since 3.4.0
+ */
+ def string: StructField = StructField(name, StringType)
+
+ /**
+ * Creates a new `StructField` of type date.
+ * @since 3.4.0
+ */
+ def date: StructField = StructField(name, DateType)
+
+ /**
+ * Creates a new `StructField` of type decimal.
+ * @since 3.4.0
+ */
+ def decimal: StructField = StructField(name, DecimalType.USER_DEFAULT)
+
+ /**
+ * Creates a new `StructField` of type decimal.
+ * @since 3.4.0
+ */
+ def decimal(precision: Int, scale: Int): StructField =
+ StructField(name, DecimalType(precision, scale))
+
+ /**
+ * Creates a new `StructField` of type timestamp.
+ * @since 3.4.0
+ */
+ def timestamp: StructField = StructField(name, TimestampType)
+
+ /**
+ * Creates a new `StructField` of type binary.
+ * @since 3.4.0
+ */
+ def binary: StructField = StructField(name, BinaryType)
+
+ /**
+ * Creates a new `StructField` of type array.
+ * @since 3.4.0
+ */
+ def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType))
+
+ /**
+ * Creates a new `StructField` of type map.
+ * @since 3.4.0
+ */
+ def map(keyType: DataType, valueType: DataType): StructField =
+ map(MapType(keyType, valueType))
+
+ /**
+ * Creates a new `StructField` of type map.
+ * @since 3.4.0
+ */
+ def map(mapType: MapType): StructField = StructField(name, mapType)
+
+ /**
+ * Creates a new `StructField` of type struct.
+ * @since 3.4.0
+ */
+ def struct(fields: StructField*): StructField = struct(StructType(fields))
+
+ /**
+ * Creates a new `StructField` of type struct.
+ * @since 3.4.0
+ */
+ def struct(structType: StructType): StructField = StructField(name, structType)
+}
+
+/**
+ * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. To
+ * create a [[TypedColumn]], use the `as` function on a [[Column]].
+ *
+ * @tparam T
+ * The input type expected for this expression. Can be `Any` if the expression is type checked
+ * by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
+ * @tparam U
+ * The output type of this column.
+ *
+ * @since 3.4.0
+ */
+class TypedColumn[-T, U] private[sql] (
+ expr: proto.Expression,
+ private[sql] val encoder: AgnosticEncoder[U])
+ extends Column(expr) {
+
+ /**
+ * Gives the [[TypedColumn]] a name (alias). If the current `TypedColumn` has metadata
+ * associated with it, this metadata will be propagated to the new column.
+ *
+ * @group expr_ops
+ * @since 3.4.0
+ */
+ override def name(alias: String): TypedColumn[T, U] =
+ new TypedColumn[T, U](super.name(alias).expr, encoder)
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
new file mode 100644
index 0000000000000..17b95018f8986
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -0,0 +1,441 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto.{NAReplace, Relation}
+import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
+import org.apache.spark.connect.proto.NAReplace.Replacement
+
+/**
+ * Functionality for working with missing data in `DataFrame`s.
+ *
+ * @since 3.4.0
+ */
+final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation) {
+
+ /**
+ * Returns a new `DataFrame` that drops rows containing any null or NaN values.
+ *
+ * @since 3.4.0
+ */
+ def drop(): DataFrame = buildDropDataFrame(None, None)
+
+ /**
+ * Returns a new `DataFrame` that drops rows containing null or NaN values.
+ *
+ * If `how` is "any", then drop rows containing any null or NaN values. If `how` is "all", then
+ * drop rows only if every column is null or NaN for that row.
+ *
+ * @since 3.4.0
+ */
+ def drop(how: String): DataFrame = {
+ buildDropDataFrame(None, buildMinNonNulls(how))
+ }
+
+ /**
+ * Returns a new `DataFrame` that drops rows containing any null or NaN values in the specified
+ * columns.
+ *
+ * @since 3.4.0
+ */
+ def drop(cols: Array[String]): DataFrame = drop(cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values
+ * in the specified columns.
+ *
+ * @since 3.4.0
+ */
+ def drop(cols: Seq[String]): DataFrame = buildDropDataFrame(Some(cols), None)
+
+ /**
+ * Returns a new `DataFrame` that drops rows containing null or NaN values in the specified
+ * columns.
+ *
+ * If `how` is "any", then drop rows containing any null or NaN values in the specified columns.
+ * If `how` is "all", then drop rows only if every specified column is null or NaN for that row.
+ *
+ * @since 3.4.0
+ */
+ def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values in
+ * the specified columns.
+ *
+ * If `how` is "any", then drop rows containing any null or NaN values in the specified columns.
+ * If `how` is "all", then drop rows only if every specified column is null or NaN for that row.
+ *
+ * @since 3.4.0
+ */
+ def drop(how: String, cols: Seq[String]): DataFrame = {
+ buildDropDataFrame(Some(cols), buildMinNonNulls(how))
+ }
+
+ /**
+ * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and
+ * non-NaN values.
+ *
+ * @since 3.4.0
+ */
+ def drop(minNonNulls: Int): DataFrame = {
+ buildDropDataFrame(None, Some(minNonNulls))
+ }
+
+ /**
+ * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and
+ * non-NaN values in the specified columns.
+ *
+ * @since 3.4.0
+ */
+ def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new `DataFrame` that drops rows containing less than `minNonNulls`
+ * non-null and non-NaN values in the specified columns.
+ *
+ * @since 3.4.0
+ */
+ def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
+ buildDropDataFrame(Some(cols), Some(minNonNulls))
+ }
+
+ private def buildMinNonNulls(how: String): Option[Int] = {
+ how.toLowerCase(Locale.ROOT) match {
+ case "any" => None // No-Op. Do nothing.
+ case "all" => Some(1)
+ case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
+ }
+ }
+
+ private def buildDropDataFrame(
+ cols: Option[Seq[String]],
+ minNonNulls: Option[Int]): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ val dropNaBuilder = builder.getDropNaBuilder.setInput(root)
+ cols.foreach(c => dropNaBuilder.addAllCols(c.asJava))
+ minNonNulls.foreach(dropNaBuilder.setMinNonNulls)
+ }
+ }
+
+ /**
+ * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: Long): DataFrame = {
+ buildFillDataFrame(None, GLiteral.newBuilder().setLong(value).build())
+ }
+
+ /**
+ * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a
+ * specified column is not a numeric column, it is ignored.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
+ * numeric columns. If a specified column is not a numeric column, it is ignored.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: Long, cols: Seq[String]): DataFrame = {
+ buildFillDataFrame(Some(cols), GLiteral.newBuilder().setLong(value).build())
+ }
+
+ /**
+ * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: Double): DataFrame = {
+ buildFillDataFrame(None, GLiteral.newBuilder().setDouble(value).build())
+ }
+
+ /**
+ * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a
+ * specified column is not a numeric column, it is ignored.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
+ * numeric columns. If a specified column is not a numeric column, it is ignored.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: Double, cols: Seq[String]): DataFrame = {
+ buildFillDataFrame(Some(cols), GLiteral.newBuilder().setDouble(value).build())
+ }
+
+ /**
+ * Returns a new `DataFrame` that replaces null values in string columns with `value`.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: String): DataFrame = {
+ buildFillDataFrame(None, GLiteral.newBuilder().setString(value).build())
+ }
+
+ /**
+ * Returns a new `DataFrame` that replaces null values in specified string columns. If a
+ * specified column is not a string column, it is ignored.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified string
+ * columns. If a specified column is not a string column, it is ignored.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: String, cols: Seq[String]): DataFrame = {
+ buildFillDataFrame(Some(cols), GLiteral.newBuilder().setString(value).build())
+ }
+
+ /**
+ * Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: Boolean): DataFrame = {
+ buildFillDataFrame(None, GLiteral.newBuilder().setBoolean(value).build())
+ }
+
+ /**
+ * Returns a new `DataFrame` that replaces null values in specified boolean columns. If a
+ * specified column is not a boolean column, it is ignored.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified boolean
+ * columns. If a specified column is not a boolean column, it is ignored.
+ *
+ * @since 3.4.0
+ */
+ def fill(value: Boolean, cols: Seq[String]): DataFrame = {
+ buildFillDataFrame(Some(cols), GLiteral.newBuilder().setBoolean(value).build())
+ }
+
+ private def buildFillDataFrame(cols: Option[Seq[String]], value: GLiteral): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ val fillNaBuilder = builder.getFillNaBuilder.setInput(root)
+ fillNaBuilder.addValues(value)
+ cols.foreach(c => fillNaBuilder.addAllCols(c.asJava))
+ }
+ }
+
+ /**
+ * Returns a new `DataFrame` that replaces null values.
+ *
+ * The key of the map is the column name, and the value of the map is the replacement value. The
+ * value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`,
+ * `Boolean`. Replacement values are cast to the column data type.
+ *
+ * For example, the following replaces null values in column "A" with string "unknown", and null
+ * values in column "B" with numeric value 1.0.
+ * {{{
+ * import com.google.common.collect.ImmutableMap;
+ * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0));
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def fill(valueMap: java.util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq)
+
+ /**
+ * Returns a new `DataFrame` that replaces null values.
+ *
+ * The key of the map is the column name, and the value of the map is the replacement value. The
+ * value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`,
+ * `Boolean`. Replacement values are cast to the column data type.
+ *
+ * For example, the following replaces null values in column "A" with string "unknown", and null
+ * values in column "B" with numeric value 1.0.
+ * {{{
+ * import com.google.common.collect.ImmutableMap;
+ * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0));
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq)
+
+ private def fillMap(values: Seq[(String, Any)]): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ val fillNaBuilder = builder.getFillNaBuilder.setInput(root)
+ values.map { case (colName, replaceValue) =>
+ fillNaBuilder.addCols(colName).addValues(functions.lit(replaceValue).expr.getLiteral)
+ }
+ }
+ }
+
+ /**
+ * Replaces values matching keys in `replacement` map with the corresponding values.
+ *
+ * {{{
+ * import com.google.common.collect.ImmutableMap;
+ *
+ * // Replaces all occurrences of 1.0 with 2.0 in column "height".
+ * df.na.replace("height", ImmutableMap.of(1.0, 2.0));
+ *
+ * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
+ * df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
+ *
+ * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
+ * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
+ * }}}
+ *
+ * @param col
+ * name of the column to apply the value replacement. If `col` is "*", replacement is applied
+ * on all string, numeric or boolean columns.
+ * @param replacement
+ * value replacement map. Key and value of `replacement` map must have the same type, and can
+ * only be doubles, strings or booleans. The map value can have nulls.
+ * @since 3.4.0
+ */
+ def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame =
+ replace(col, replacement.asScala.toMap)
+
+ /**
+ * (Scala-specific) Replaces values matching keys in `replacement` map.
+ *
+ * {{{
+ * // Replaces all occurrences of 1.0 with 2.0 in column "height".
+ * df.na.replace("height", Map(1.0 -> 2.0));
+ *
+ * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
+ * df.na.replace("name", Map("UNKNOWN" -> "unnamed"));
+ *
+ * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
+ * df.na.replace("*", Map("UNKNOWN" -> "unnamed"));
+ * }}}
+ *
+ * @param col
+ * name of the column to apply the value replacement. If `col` is "*", replacement is applied
+ * on all string, numeric or boolean columns.
+ * @param replacement
+ * value replacement map. Key and value of `replacement` map must have the same type, and can
+ * only be doubles, strings or booleans. The map value can have nulls.
+ * @since 3.4.0
+ */
+ def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
+ val cols = if (col != "*") Some(Seq(col)) else None
+ buildReplaceDataFrame(cols, buildReplacement(replacement))
+ }
+
+ /**
+ * Replaces values matching keys in `replacement` map with the corresponding values.
+ *
+ * {{{
+ * import com.google.common.collect.ImmutableMap;
+ *
+ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
+ * df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0));
+ *
+ * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
+ * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
+ * }}}
+ *
+ * @param cols
+ * list of columns to apply the value replacement. If `col` is "*", replacement is applied on
+ * all string, numeric or boolean columns.
+ * @param replacement
+ * value replacement map. Key and value of `replacement` map must have the same type, and can
+ * only be doubles, strings or booleans. The map value can have nulls.
+ * @since 3.4.0
+ */
+ def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = {
+ replace(cols.toSeq, replacement.asScala.toMap)
+ }
+
+ /**
+ * (Scala-specific) Replaces values matching keys in `replacement` map.
+ *
+ * {{{
+ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
+ * df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0));
+ *
+ * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
+ * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"));
+ * }}}
+ *
+ * @param cols
+ * list of columns to apply the value replacement. If `col` is "*", replacement is applied on
+ * all string, numeric or boolean columns.
+ * @param replacement
+ * value replacement map. Key and value of `replacement` map must have the same type, and can
+ * only be doubles, strings or booleans. The map value can have nulls.
+ * @since 3.4.0
+ */
+ def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
+ buildReplaceDataFrame(Some(cols), buildReplacement(replacement))
+ }
+
+ private def buildReplaceDataFrame(
+ cols: Option[Seq[String]],
+ replacements: Iterable[NAReplace.Replacement]): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ val replaceBuilder = builder.getReplaceBuilder.setInput(root)
+ replaceBuilder.addAllReplacements(replacements.asJava)
+ cols.foreach(c => replaceBuilder.addAllCols(c.asJava))
+ }
+ }
+
+ private def buildReplacement[T](replacement: Map[T, T]): Iterable[NAReplace.Replacement] = {
+ // Convert the NumericType in replacement map to DoubleType,
+ // while leaving StringType, BooleanType and null untouched.
+ val replacementMap: Map[_, _] = replacement.map {
+ case (k, v: String) => (k, v)
+ case (k, v: Boolean) => (k, v)
+ case (k: String, null) => (k, null)
+ case (k: Boolean, null) => (k, null)
+ case (k, null) => (convertToDouble(k), null)
+ case (k, v) => (convertToDouble(k), convertToDouble(v))
+ }
+ replacementMap.map { case (oldValue, newValue) =>
+ Replacement
+ .newBuilder()
+ .setOldValue(functions.lit(oldValue).expr.getLiteral)
+ .setNewValue(functions.lit(newValue).expr.getLiteral)
+ .build()
+ }
+ }
+
+ private def convertToDouble(v: Any): Double = v match {
+ case v: Float => v.toDouble
+ case v: Double => v
+ case v: Long => v.toDouble
+ case v: Int => v.toDouble
+ case v =>
+ throw new IllegalArgumentException(s"Unsupported value type ${v.getClass.getName} ($v).")
+ }
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
new file mode 100644
index 0000000000000..40f9ac1df2b22
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -0,0 +1,580 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Properties
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.Stable
+import org.apache.spark.connect.proto.Parse.ParseFormat
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
+import org.apache.spark.sql.connect.common.DataTypeProtoConverter
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Interface used to load a [[Dataset]] from external storage systems (e.g. file systems,
+ * key-value stores, etc). Use `SparkSession.read` to access this.
+ *
+ * @since 3.4.0
+ */
+@Stable
+class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging {
+
+ /**
+ * Specifies the input data source format.
+ *
+ * @since 3.4.0
+ */
+ def format(source: String): DataFrameReader = {
+ this.source = source
+ this
+ }
+
+ /**
+ * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema
+ * automatically from data. By specifying the schema here, the underlying data source can skip
+ * the schema inference step, and thus speed up data loading.
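+ *
+ * A minimal sketch (the schema and the CSV path below are only illustrative):
+ * {{{
+ *   import org.apache.spark.sql.types._
+ *
+ *   val schema = StructType(Seq(
+ *     StructField("name", StringType),
+ *     StructField("age", IntegerType)))
+ *   spark.read.schema(schema).csv("data/people.csv")
+ * }}}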
+ *
+ * @since 3.4.0
+ */
+ def schema(schema: StructType): DataFrameReader = {
+ if (schema != null) {
+ val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+ this.userSpecifiedSchema = Option(replaced)
+ }
+ this
+ }
+
+ /**
+ * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON)
+ * can infer the input schema automatically from data. By specifying the schema here, the
+ * underlying data source can skip the schema inference step, and thus speed up data loading.
+ *
+ * {{{
+ * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv")
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def schema(schemaString: String): DataFrameReader = {
+ schema(StructType.fromDDL(schemaString))
+ }
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
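+ *
+ * For example, a minimal sketch (the path is only illustrative; "header" is a CSV data source
+ * option):
+ * {{{
+ *   spark.read.format("csv").option("header", "true").load("data/people.csv")
+ * }}}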
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: String): DataFrameReader = {
+ this.extraOptions = this.extraOptions + (key -> value)
+ this
+ }
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString)
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Long): DataFrameReader = option(key, value.toString)
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Double): DataFrameReader = option(key, value.toString)
+
+ /**
+ * (Scala-specific) Adds input options for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def options(options: scala.collection.Map[String, String]): DataFrameReader = {
+ this.extraOptions ++= options
+ this
+ }
+
+ /**
+ * Adds input options for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def options(options: java.util.Map[String, String]): DataFrameReader = {
+ this.options(options.asScala)
+ this
+ }
+
+ /**
+ * Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external
+ * key-value stores).
+ *
+ * @since 3.4.0
+ */
+ def load(): DataFrame = {
+ load(Seq.empty: _*) // force invocation of `load(...varargs...)`
+ }
+
+ /**
+ * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a
+ * local or distributed file system).
+ *
+ * @since 3.4.0
+ */
+ def load(path: String): DataFrame = {
+ // force invocation of `load(...varargs...)`
+ load(Seq(path): _*)
+ }
+
+ /**
+ * Loads input in as a `DataFrame`, for data sources that support multiple paths. Only works if
+ * the source is a HadoopFsRelationProvider.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def load(paths: String*): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ val dataSourceBuilder = builder.getReadBuilder.getDataSourceBuilder
+ assertSourceFormatSpecified()
+ dataSourceBuilder.setFormat(source)
+ userSpecifiedSchema.foreach(schema => dataSourceBuilder.setSchema(schema.toDDL))
+ extraOptions.foreach { case (k, v) =>
+ dataSourceBuilder.putOptions(k, v)
+ }
+ paths.foreach(path => dataSourceBuilder.addPaths(path))
+ builder.build()
+ }
+ }
+
+ /**
+ * Construct a `DataFrame` representing the database table accessible via JDBC URL url named
+ * table and connection properties.
+ *
+ * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
+ * in
+ * Data Source Option in the version you use.
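+ *
+ * A minimal sketch (the URL, table name and credentials below are placeholders):
+ * {{{
+ *   val props = new java.util.Properties()
+ *   props.put("user", "username")
+ *   props.put("password", "password")
+ *   spark.read.jdbc("jdbc:postgresql://host:5432/db", "schema.table", props)
+ * }}}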
+ *
+ * @since 3.4.0
+ */
+ def jdbc(url: String, table: String, properties: Properties): DataFrame = {
+ // properties should override settings in extraOptions.
+ this.extraOptions ++= properties.asScala
+ // explicit url and dbtable should override all
+ this.extraOptions ++= Seq("url" -> url, "dbtable" -> table)
+ format("jdbc").load()
+ }
+
+ // scalastyle:off line.size.limit
+ /**
+ * Construct a `DataFrame` representing the database table accessible via JDBC URL url named
+ * table. Partitions of the table will be retrieved in parallel based on the parameters passed
+ * to this function.
+ *
+ * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
+ * your external database systems.
+ *
+ * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
+ * in
+ * Data Source Option in the version you use.
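+ *
+ * A minimal sketch of a partitioned read (all connection details below are placeholders):
+ * {{{
+ *   spark.read.jdbc(
+ *     "jdbc:postgresql://host:5432/db", "schema.table",
+ *     columnName = "id", lowerBound = 0L, upperBound = 10000L, numPartitions = 10,
+ *     connectionProperties = new java.util.Properties())
+ * }}}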
+ *
+ * @param table
+ * Name of the table in the external database.
+ * @param columnName
+ * Alias of `partitionColumn` option. Refer to `partitionColumn` in
+ * Data Source Option in the version you use.
+ * @param connectionProperties
+ * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least
+ * a "user" and "password" property should be included. "fetchsize" can be used to control the
+ * number of rows per fetch and "queryTimeout" can be used to wait for a Statement object to
+ * execute to the given number of seconds.
+ * @since 3.4.0
+ */
+ // scalastyle:on line.size.limit
+ def jdbc(
+ url: String,
+ table: String,
+ columnName: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int,
+ connectionProperties: Properties): DataFrame = {
+ // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions.
+ this.extraOptions ++= Map(
+ "partitionColumn" -> columnName,
+ "lowerBound" -> lowerBound.toString,
+ "upperBound" -> upperBound.toString,
+ "numPartitions" -> numPartitions.toString)
+ jdbc(url, table, connectionProperties)
+ }
+
+ /**
+ * Construct a `DataFrame` representing the database table accessible via JDBC URL url named
+ * table using connection properties. The `predicates` parameter gives a list of expressions
+ * suitable for inclusion in WHERE clauses; each one defines one partition of the `DataFrame`.
+ *
+ * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
+ * your external database systems.
+ *
+ * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
+ * in
+ * Data Source Option in the version you use.
+ *
+ * @param table
+ * Name of the table in the external database.
+ * @param predicates
+ * Condition in the where clause for each partition.
+ * @param connectionProperties
+ * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least
+ * a "user" and "password" property should be included. "fetchsize" can be used to control the
+ * number of rows per fetch.
+ * @since 3.4.0
+ */
+ def jdbc(
+ url: String,
+ table: String,
+ predicates: Array[String],
+ connectionProperties: Properties): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ val dataSourceBuilder = builder.getReadBuilder.getDataSourceBuilder
+ format("jdbc")
+ dataSourceBuilder.setFormat(source)
+ predicates.foreach(predicate => dataSourceBuilder.addPredicates(predicate))
+ this.extraOptions ++= Seq("url" -> url, "dbtable" -> table)
+ val params = extraOptions ++ connectionProperties.asScala
+ params.foreach { case (k, v) =>
+ dataSourceBuilder.putOptions(k, v)
+ }
+ builder.build()
+ }
+ }
+
+ /**
+ * Loads a JSON file and returns the results as a `DataFrame`.
+ *
+ * See the documentation on the overloaded `json()` method with varargs for more details.
+ *
+ * @since 3.4.0
+ */
+ def json(path: String): DataFrame = {
+ // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+ json(Seq(path): _*)
+ }
+
+ /**
+ * Loads JSON files and returns the results as a `DataFrame`.
+ *
+ * JSON Lines (newline-delimited JSON) is supported by
+ * default. For JSON (one record per file), set the `multiLine` option to true.
+ *
+ * This function goes through the input once to determine the input schema. If you know the
+ * schema in advance, use the version that specifies the schema to avoid the extra scan.
+ *
+ * You can find the JSON-specific options for reading JSON files in
+ * Data Source Option in the version you use.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def json(paths: String*): DataFrame = {
+ format("json").load(paths: _*)
+ }
+
+ /**
+ * Loads a `Dataset[String]` storing JSON objects (JSON Lines
+ * text format or newline-delimited JSON) and returns the result as a `DataFrame`.
+ *
+ * Unless the schema is specified using `schema` function, this function goes through the input
+ * once to determine the input schema.
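+ *
+ * A minimal sketch (assuming `jsonDs` is an existing `Dataset[String]` holding one JSON object
+ * per record):
+ * {{{
+ *   spark.read.schema("name STRING, age INT").json(jsonDs)
+ * }}}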
+ *
+ * @param jsonDataset
+ * input Dataset with one JSON object per record
+ * @since 3.4.0
+ */
+ def json(jsonDataset: Dataset[String]): DataFrame =
+ parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON)
+
+ /**
+ * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other
+ * overloaded `csv()` method for more details.
+ *
+ * @since 3.4.0
+ */
+ def csv(path: String): DataFrame = {
+ // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+ csv(Seq(path): _*)
+ }
+
+ /**
+ * Loads CSV files and returns the result as a `DataFrame`.
+ *
+ * This function will go through the input once to determine the input schema if `inferSchema`
+ * is enabled. To avoid going through the entire data once, disable `inferSchema` option or
+ * specify the schema explicitly using `schema`.
+ *
+ * You can find the CSV-specific options for reading CSV files in
+ * Data Source Option in the version you use.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def csv(paths: String*): DataFrame = format("csv").load(paths: _*)
+
+ /**
+ * Loads a `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
+ *
+ * If the schema is not specified using `schema` function and `inferSchema` option is enabled,
+ * this function goes through the input once to determine the input schema.
+ *
+ * If the schema is not specified using `schema` function and `inferSchema` option is disabled,
+ * it determines the columns as string types and it reads only the first line to determine the
+ * names and the number of fields.
+ *
+ * If `enforceSchema` is set to `false`, only the CSV header in the first line is checked to
+ * conform to the specified or inferred schema.
+ *
+ * @note
+ * if the `header` option is set to `true` when calling this API, all lines matching the header
+ * will be removed if they exist.
+ * @param csvDataset
+ * input Dataset with one CSV row per record
+ * @since 3.4.0
+ */
+ def csv(csvDataset: Dataset[String]): DataFrame =
+ parse(csvDataset, ParseFormat.PARSE_FORMAT_CSV)
+
+ /**
+ * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the
+ * other overloaded `parquet()` method for more details.
+ *
+ * @since 3.4.0
+ */
+ def parquet(path: String): DataFrame = {
+ // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+ parquet(Seq(path): _*)
+ }
+
+ /**
+ * Loads a Parquet file, returning the result as a `DataFrame`.
+ *
+ * Parquet-specific option(s) for reading Parquet files can be found in Data
+ * Source Option in the version you use.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def parquet(paths: String*): DataFrame = {
+ format("parquet").load(paths: _*)
+ }
+
+ /**
+ * Loads an ORC file and returns the result as a `DataFrame`.
+ *
+ * @param path
+ * input path
+ * @since 3.4.0
+ */
+ def orc(path: String): DataFrame = {
+ // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+ orc(Seq(path): _*)
+ }
+
+ /**
+ * Loads ORC files and returns the result as a `DataFrame`.
+ *
+ * ORC-specific option(s) for reading ORC files can be found in Data
+ * Source Option in the version you use.
+ *
+ * @param paths
+ * input paths
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def orc(paths: String*): DataFrame = format("orc").load(paths: _*)
+
+ /**
+ * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch
+ * reading and the returned DataFrame is the batch scan query plan of this table. If it's a
+ * view, the returned DataFrame is simply the query plan of the view, which can either be a
+ * batch or streaming query plan.
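+ *
+ * For example, a minimal sketch (the table name is a placeholder):
+ * {{{
+ *   spark.read.table("my_db.my_table")
+ * }}}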
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table or view. If a database is
+ * specified, it identifies the table/view from the database. Otherwise, it first attempts to
+ * find a temporary view with the given name and then match the table/view from the current
+ * database. Note that the global temporary view database is also valid here.
+ * @since 3.4.0
+ */
+ def table(tableName: String): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ builder.getReadBuilder.getNamedTableBuilder
+ .setUnparsedIdentifier(tableName)
+ .putAllOptions(extraOptions.toMap.asJava)
+ }
+ }
+
+ /**
+ * Loads text files and returns a `DataFrame` whose schema starts with a string column named
+ * "value", and followed by partitioned columns if there are any. See the documentation on the
+ * other overloaded `text()` method for more details.
+ *
+ * @since 3.4.0
+ */
+ def text(path: String): DataFrame = {
+ // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+ text(Seq(path): _*)
+ }
+
+ /**
+ * Loads text files and returns a `DataFrame` whose schema starts with a string column named
+ * "value", and followed by partitioned columns if there are any. The text files must be encoded
+ * as UTF-8.
+ *
+ * By default, each line in the text files is a new row in the resulting DataFrame. For example:
+ * {{{
+ * // Scala:
+ * spark.read.text("/path/to/spark/README.md")
+ *
+ * // Java:
+ * spark.read().text("/path/to/spark/README.md")
+ * }}}
+ *
+ * You can find the text-specific options for reading text files in
+ * Data Source Option in the version you use.
+ *
+ * @param paths
+ * input paths
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def text(paths: String*): DataFrame = format("text").load(paths: _*)
+
+ /**
+ * Loads text files and returns a [[Dataset]] of String. See the documentation on the other
+ * overloaded `textFile()` method for more details.
+ * @since 3.4.0
+ */
+ def textFile(path: String): Dataset[String] = {
+ // This method ensures that calls that explicitly need a single argument work, see SPARK-16009
+ textFile(Seq(path): _*)
+ }
+
+ /**
+ * Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset
+ * contains a single string column named "value". The text files must be encoded as UTF-8.
+ *
+ * If the directory structure of the text files contains partitioning information, those are
+ * ignored in the resulting Dataset. To include partitioning information as columns, use `text`.
+ *
+ * By default, each line in the text files is a new row in the resulting DataFrame. For example:
+ * {{{
+ * // Scala:
+ * spark.read.textFile("/path/to/spark/README.md")
+ *
+ * // Java:
+ * spark.read().textFile("/path/to/spark/README.md")
+ * }}}
+ *
+ * You can set the text-specific options as specified in `DataFrameReader.text`.
+ *
+ * @param paths
+ * input paths
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def textFile(paths: String*): Dataset[String] = {
+ assertNoSpecifiedSchema("textFile")
+ text(paths: _*).select("value").as(StringEncoder)
+ }
+
+ private def assertSourceFormatSpecified(): Unit = {
+ if (source == null) {
+ throw new IllegalArgumentException("The source format must be specified.")
+ }
+ }
+
+ private def parse(ds: Dataset[String], format: ParseFormat): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ val parseBuilder = builder.getParseBuilder
+ .setInput(ds.plan.getRoot)
+ .setFormat(format)
+ userSpecifiedSchema.foreach(schema =>
+ parseBuilder.setSchema(DataTypeProtoConverter.toConnectProtoType(schema)))
+ extraOptions.foreach { case (k, v) =>
+ parseBuilder.putOptions(k, v)
+ }
+ }
+ }
+
+ /**
+ * A convenient function for schema validation in APIs.
+ */
+ private def assertNoSpecifiedSchema(operation: String): Unit = {
+ if (userSpecifiedSchema.nonEmpty) {
+ throw QueryCompilationErrors.userSpecifiedSchemaUnsupportedError(operation)
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////////////
+ // Builder pattern config options
+ ///////////////////////////////////////////////////////////////////////////////////////
+
+ private var source: String = _
+
+ private var userSpecifiedSchema: Option[StructType] = None
+
+ private var extraOptions = CaseInsensitiveMap[String](Map.empty)
+
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
new file mode 100644
index 0000000000000..0d4372b8738ee
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -0,0 +1,592 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.{lang => jl, util => ju}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto.{Relation, StatSampleBy}
+import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.util.sketch.CountMinSketch
+
+/**
+ * Statistic functions for `DataFrame`s.
+ *
+ * @since 3.4.0
+ */
+final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, root: Relation) {
+
+ /**
+ * Calculates the approximate quantiles of a numerical column of a DataFrame.
+ *
+ * The result of this algorithm has the following deterministic bound: If the DataFrame has N
+ * elements and if we request the quantile at probability `p` up to error `err`, then the
+ * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is
+ * close to (p * N). More precisely,
+ *
+ * {{{
+ * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N)
+ * }}}
+ *
+ * This method implements a variation of the Greenwald-Khanna algorithm (with some speed
+ * optimizations). The algorithm was first presented in Space-efficient Online Computation of
+ * Quantile Summaries by Greenwald and Khanna.
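+ *
+ * For example, a minimal sketch (assuming `df` has a numeric column "age"):
+ * {{{
+ *   val Array(q1, median, q3) = df.stat.approxQuantile("age", Array(0.25, 0.5, 0.75), 0.05)
+ * }}}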
+ *
+ * @param col
+ * the name of the numerical column
+ * @param probabilities
+ * a list of quantile probabilities. Each number must belong to [0, 1]. For example 0 is the
+ * minimum, 0.5 is the median, 1 is the maximum.
+ * @param relativeError
+ * The relative target precision to achieve (greater than or equal to 0). If set to zero, the
+ * exact quantiles are computed, which could be very expensive. Note that values greater than
+ * 1 are accepted but give the same result as 1.
+ * @return
+ * the approximate quantiles at the given probabilities
+ *
+ * @note
+ * null and NaN values will be removed from the numerical column before calculation. If the
+ * dataframe is empty or the column only contains null or NaN, an empty array is returned.
+ *
+ * @since 3.4.0
+ */
+ def approxQuantile(
+ col: String,
+ probabilities: Array[Double],
+ relativeError: Double): Array[Double] = {
+ approxQuantile(Array(col), probabilities, relativeError).head
+ }
+
+ /**
+ * Calculates the approximate quantiles of numerical columns of a DataFrame.
+ * @see
+ * `approxQuantile(col:Str* approxQuantile)` for detailed description.
+ *
+ * @param cols
+ * the names of the numerical columns
+ * @param probabilities
+ * a list of quantile probabilities. Each number must belong to [0, 1]. For example 0 is the
+ * minimum, 0.5 is the median, 1 is the maximum.
+ * @param relativeError
+ * The relative target precision to achieve (greater than or equal to 0). If set to zero, the
+ * exact quantiles are computed, which could be very expensive. Note that values greater than
+ * 1 are accepted but give the same result as 1.
+ * @return
+ * the approximate quantiles at the given probabilities of each column
+ *
+ * @note
+ * null and NaN values will be ignored in numerical columns before calculation. For columns
+ * only containing null or NaN values, an empty array is returned.
+ *
+ * @since 3.4.0
+ */
+ def approxQuantile(
+ cols: Array[String],
+ probabilities: Array[Double],
+ relativeError: Double): Array[Array[Double]] = {
+ require(
+ probabilities.forall(p => p >= 0.0 && p <= 1.0),
+ "percentile should be in the range [0.0, 1.0]")
+ require(relativeError >= 0, s"Relative Error must be non-negative but got $relativeError")
+ sparkSession
+ .newDataset(approxQuantileResultEncoder) { builder =>
+ val approxQuantileBuilder = builder.getApproxQuantileBuilder
+ .setInput(root)
+ .setRelativeError(relativeError)
+ cols.foreach(approxQuantileBuilder.addCols)
+ probabilities.foreach(approxQuantileBuilder.addProbabilities)
+ }
+ .head()
+ }
+
+ /**
+ * Calculate the sample covariance of two numerical columns of a DataFrame.
+ * @param col1
+ * the name of the first column
+ * @param col2
+ * the name of the second column
+ * @return
+ * the covariance of the two columns.
+ *
+ * {{{
+ * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
+ * .withColumn("rand2", rand(seed=27))
+ * df.stat.cov("rand1", "rand2")
+ * res1: Double = 0.065...
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def cov(col1: String, col2: String): Double = {
+ sparkSession
+ .newDataset(PrimitiveDoubleEncoder) { builder =>
+ builder.getCovBuilder.setInput(root).setCol1(col1).setCol2(col2)
+ }
+ .head()
+ }
+
+ /**
+ * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
+ * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
+ * MLlib's Statistics.
+ *
+ * @param col1
+ * the name of the column
+ * @param col2
+ * the name of the column to calculate the correlation against
+ * @return
+ * The Pearson Correlation Coefficient as a Double.
+ *
+ * {{{
+ * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
+ * .withColumn("rand2", rand(seed=27))
+ * df.stat.corr("rand1", "rand2")
+ * res1: Double = 0.613...
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def corr(col1: String, col2: String, method: String): Double = {
+ require(
+ method == "pearson",
+ "Currently only the calculation of the Pearson Correlation " +
+ "coefficient is supported.")
+ sparkSession
+ .newDataset(PrimitiveDoubleEncoder) { builder =>
+ builder.getCorrBuilder.setInput(root).setCol1(col1).setCol2(col2)
+ }
+ .head()
+ }
+
+ /**
+ * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame.
+ *
+ * @param col1
+ * the name of the column
+ * @param col2
+ * the name of the column to calculate the correlation against
+ * @return
+ * The Pearson Correlation Coefficient as a Double.
+ *
+ * {{{
+ * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
+ * .withColumn("rand2", rand(seed=27))
+ * df.stat.corr("rand1", "rand2", "pearson")
+ * res1: Double = 0.613...
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def corr(col1: String, col2: String): Double = {
+ corr(col1, col2, "pearson")
+ }
+
+ /**
+ * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+ * The first column of each row will be the distinct values of `col1` and the column names will
+ * be the distinct values of `col2`. The name of the first column will be `col1_col2`. Counts
+ * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts.
+ * Null elements will be replaced by "null", and back ticks will be dropped from elements if
+ * they exist.
+ *
+ * @param col1
+ * The name of the first column. Distinct items will make the first item of each row.
+ * @param col2
+ * The name of the second column. Distinct items will make the column names of the DataFrame.
+ * @return
+ * A DataFrame containing the contingency table.
+ *
+ * {{{
+ * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3)))
+ * .toDF("key", "value")
+ * val ct = df.stat.crosstab("key", "value")
+ * ct.show()
+ * +---------+---+---+---+
+ * |key_value| 1| 2| 3|
+ * +---------+---+---+---+
+ * | 2| 2| 0| 1|
+ * | 1| 1| 1| 0|
+ * | 3| 0| 1| 1|
+ * +---------+---+---+---+
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def crosstab(col1: String, col2: String): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ builder.getCrosstabBuilder.setInput(root).setCol1(col1).setCol2(col2)
+ }
+ }
+
+ /**
+ * Finding frequent items for columns, possibly with false positives. Using the frequent element
+ * count algorithm described here,
+ * proposed by Karp, Schenker, and Papadimitriou. The `support` should be greater than 1e-4.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting `DataFrame`.
+ *
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @param support
+ * The minimum frequency for an item to be considered `frequent`. Should be greater than 1e-4.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
+ *
+ * {{{
+ * val rows = Seq.tabulate(100) { i =>
+ * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0)
+ * }
+ * val df = spark.createDataFrame(rows).toDF("a", "b")
+ * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns
+ * // "a" and "b"
+ * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4)
+ * freqSingles.show()
+ * +-----------+-------------+
+ * |a_freqItems| b_freqItems|
+ * +-----------+-------------+
+ * | [1, 99]|[-1.0, -99.0]|
+ * +-----------+-------------+
+ * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b"
+ * val pairDf = df.select(struct("a", "b").as("a-b"))
+ * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1)
+ * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show()
+ * +----------+
+ * | freq_ab|
+ * +----------+
+ * | [1,-1.0]|
+ * | ... |
+ * +----------+
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def freqItems(cols: Array[String], support: Double): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ val freqItemsBuilder = builder.getFreqItemsBuilder.setInput(root).setSupport(support)
+ cols.foreach(freqItemsBuilder.addCols)
+ }
+ }
+
+ /**
+ * Finding frequent items for columns, possibly with false positives. Using the frequent element
+ * count algorithm described here,
+ * proposed by Karp, Schenker, and Papadimitriou. Uses a `default` support of 1%.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting `DataFrame`.
+ *
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
+ *
+ * @since 3.4.0
+ */
+ def freqItems(cols: Array[String]): DataFrame = {
+ freqItems(cols, 0.01)
+ }
+
+ /**
+ * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the
+ * frequent element count algorithm described here, proposed by Karp, Schenker, and
+ * Papadimitriou.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting `DataFrame`.
+ *
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
+ *
+ * {{{
+ * val rows = Seq.tabulate(100) { i =>
+ * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0)
+ * }
+ * val df = spark.createDataFrame(rows).toDF("a", "b")
+ * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns
+ * // "a" and "b"
+ * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4)
+ * freqSingles.show()
+ * +-----------+-------------+
+ * |a_freqItems| b_freqItems|
+ * +-----------+-------------+
+ * | [1, 99]|[-1.0, -99.0]|
+ * +-----------+-------------+
+ * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b"
+ * val pairDf = df.select(struct("a", "b").as("a-b"))
+ * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1)
+ * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show()
+ * +----------+
+ * | freq_ab|
+ * +----------+
+ * | [1,-1.0]|
+ * | ... |
+ * +----------+
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def freqItems(cols: Seq[String], support: Double): DataFrame = {
+ freqItems(cols.toArray, support)
+ }
+
+ /**
+ * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the
+ * frequent element count algorithm described here, proposed by Karp, Schenker, and
+ * Papadimitriou. Uses a `default` support of 1%.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting `DataFrame`.
+ *
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
+ *
+ * @since 3.4.0
+ */
+ def freqItems(cols: Seq[String]): DataFrame = {
+ freqItems(cols.toArray, 0.01)
+ }
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction given on each stratum.
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
+ *
+ * {{{
+ * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2),
+ * (3, 3))).toDF("key", "value")
+ * val fractions = Map(1 -> 1.0, 3 -> 0.5)
+ * df.stat.sampleBy("key", fractions, 36L).show()
+ * +---+-----+
+ * |key|value|
+ * +---+-----+
+ * | 1| 1|
+ * | 1| 2|
+ * | 3| 2|
+ * +---+-----+
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
+ sampleBy(Column(col), fractions, seed)
+ }
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction given on each stratum.
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
+ *
+ * @since 3.4.0
+ */
+ def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+ sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+ }
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction given on each stratum.
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
+ *
+ * The stratified sample can be performed over multiple columns:
+ * {{{
+ * import org.apache.spark.sql.Row
+ * import org.apache.spark.sql.functions.struct
+ *
+ * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17),
+ * ("Alice", 10))).toDF("name", "age")
+ * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0)
+ * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show()
+ * +-----+---+
+ * | name|age|
+ * +-----+---+
+ * | Nico| 8|
+ * |Alice| 10|
+ * +-----+---+
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = {
+ require(
+ fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+ s"Fractions must be in [0, 1], but got $fractions.")
+ sparkSession.newDataFrame { builder =>
+ val sampleByBuilder = builder.getSampleByBuilder
+ .setInput(root)
+ .setCol(col.expr)
+ .setSeed(seed)
+ fractions.foreach { case (k, v) =>
+ sampleByBuilder.addFractions(
+ StatSampleBy.Fraction
+ .newBuilder()
+ .setStratum(lit(k).expr.getLiteral)
+ .setFraction(v))
+ }
+ }
+ }
+
+ /**
+ * (Java-specific) Returns a stratified sample without replacement based on the fraction given
+ * on each stratum.
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
+ *
+ * @since 3.4.0
+ */
+ def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+ sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
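+ *
+ * For example, a minimal sketch (assuming `df` has a column "id"):
+ * {{{
+ *   val sketch = df.stat.countMinSketch("id", depth = 10, width = 100, seed = 42)
+ *   val estimate = sketch.estimateCount(1L)
+ * }}}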
+ *
+ * @param colName
+ * name of the column over which the sketch is built
+ * @param depth
+ * depth of the sketch
+ * @param width
+ * width of the sketch
+ * @param seed
+ * random seed
+ * @return
+ * a `CountMinSketch` over column `colName`
+ * @since 3.4.0
+ */
+ def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
+ countMinSketch(Column(colName), depth, width, seed)
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param colName
+ * name of the column over which the sketch is built
+ * @param eps
+ * relative error of the sketch
+ * @param confidence
+ * confidence of the sketch
+ * @param seed
+ * random seed
+ * @return
+ * a `CountMinSketch` over column `colName`
+ * @since 3.4.0
+ */
+ def countMinSketch(
+ colName: String,
+ eps: Double,
+ confidence: Double,
+ seed: Int): CountMinSketch = {
+ countMinSketch(Column(colName), eps, confidence, seed)
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param col
+ * the column over which the sketch is built
+ * @param depth
+ * depth of the sketch
+ * @param width
+ * width of the sketch
+ * @param seed
+ * random seed
+ * @return
+ * a `CountMinSketch` over column `col`
+ * @since 3.4.0
+ */
+ def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
+ countMinSketch(col, eps = 2.0 / width, confidence = 1 - 1 / Math.pow(2, depth), seed)
+ }
+
+ /**
+ * Builds a Count-min Sketch over a specified column.
+ *
+ * @param col
+ * the column over which the sketch is built
+ * @param eps
+ * relative error of the sketch
+ * @param confidence
+ * confidence of the sketch
+ * @param seed
+ * random seed
+ * @return
+ * a `CountMinSketch` over column `col`
+ * @since 3.4.0
+ */
+ def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
+ val agg = Column.fn("count_min_sketch", col, lit(eps), lit(confidence), lit(seed))
+ val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
+ builder.getProjectBuilder
+ .setInput(root)
+ .addExpressions(agg.expr)
+ }
+ CountMinSketch.readFrom(ds.head())
+ }
+}
+
+private object DataFrameStatFunctions {
+ private val approxQuantileResultEncoder: ArrayEncoder[Array[Double]] =
+ ArrayEncoder(ArrayEncoder(PrimitiveDoubleEncoder, containsNull = false), containsNull = false)
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
new file mode 100644
index 0000000000000..b9d1fefb105e8
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -0,0 +1,490 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.{Locale, Properties}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.Stable
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+
+/**
+ * Interface used to write a [[Dataset]] to external storage systems (e.g. file systems, key-value
+ * stores, etc). Use `Dataset.write` to access this.
+ *
+ * @since 3.4.0
+ */
+@Stable
+final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) {
+
+ /**
+ * Specifies the behavior when data or table already exists. Options include:
+ * <ul>
+ * <li>`SaveMode.Overwrite`: overwrite the existing data.</li>
+ * <li>`SaveMode.Append`: append the data.</li>
+ * <li>`SaveMode.Ignore`: ignore the operation (i.e. no-op).</li>
+ * <li>`SaveMode.ErrorIfExists`: throw an exception at runtime.</li>
+ * </ul>
+ * The default option is `ErrorIfExists`.
+ *
+ * @since 3.4.0
+ */
+ def mode(saveMode: SaveMode): DataFrameWriter[T] = {
+ this.mode = saveMode
+ this
+ }
+
+ /**
+ * Specifies the behavior when data or table already exists. Options include:
+ * <ul>
+ * <li>`overwrite`: overwrite the existing data.</li>
+ * <li>`append`: append the data.</li>
+ * <li>`ignore`: ignore the operation (i.e. no-op).</li>
+ * <li>`error` or `errorifexists`: default option, throw an exception at runtime.</li>
+ * </ul>
+ *
+ * @since 3.4.0
+ */
+ def mode(saveMode: String): DataFrameWriter[T] = {
+ saveMode.toLowerCase(Locale.ROOT) match {
+ case "overwrite" => mode(SaveMode.Overwrite)
+ case "append" => mode(SaveMode.Append)
+ case "ignore" => mode(SaveMode.Ignore)
+ case "error" | "errorifexists" | "default" => mode(SaveMode.ErrorIfExists)
+ case _ =>
+ throw new IllegalArgumentException(s"Unknown save mode: $saveMode. Accepted " +
+ "save modes are 'overwrite', 'append', 'ignore', 'error', 'errorifexists', 'default'.")
+ }
+ }
+
+ /**
+ * Specifies the underlying output data source. Built-in options include "parquet", "json", etc.
+ *
+ * @since 3.4.0
+ */
+ def format(source: String): DataFrameWriter[T] = {
+ this.source = Some(source)
+ this
+ }
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: String): DataFrameWriter[T] = {
+ this.extraOptions = this.extraOptions + (key -> value)
+ this
+ }
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString)
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString)
+
+ /**
+ * Adds an output option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString)
+
+ /**
+ * (Scala-specific) Adds output options for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = {
+ this.extraOptions ++= options
+ this
+ }
+
+ /**
+ * Adds output options for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 3.4.0
+ */
+ def options(options: java.util.Map[String, String]): DataFrameWriter[T] = {
+ this.options(options.asScala)
+ this
+ }
+
+ /**
+ * Partitions the output by the given columns on the file system. If specified, the output is
+ * laid out on the file system similar to Hive's partitioning scheme. As an example, when we
+ * partition a dataset by year and then month, the directory layout would look like:
+ *
+ * <ul>
+ * <li>year=2016/month=01/</li>
+ * <li>year=2016/month=02/</li>
+ * </ul>
+ *
+ * Partitioning is one of the most widely used techniques to optimize physical data layout. It
+ * provides a coarse-grained index for skipping unnecessary data reads when queries have
+ * predicates on the partitioned columns. In order for partitioning to work well, the number of
+ * distinct values in each column should typically be less than tens of thousands.
+ *
+ * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark
+ * 2.1.0.
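+ *
+ * A minimal sketch (the column names and output path are placeholders):
+ * {{{
+ *   df.write.partitionBy("year", "month").parquet("/tmp/partitioned_output")
+ * }}}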
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def partitionBy(colNames: String*): DataFrameWriter[T] = {
+ this.partitioningColumns = Option(colNames)
+ this
+ }
+
+ /**
+ * Buckets the output by the given columns. If specified, the output is laid out on the file
+ * system similar to Hive's bucketing scheme, but with a different bucket hash function and is
+ * not compatible with Hive's bucketing.
+ *
+ * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark
+ * 2.1.0.
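+ *
+ * A minimal sketch (the table name and columns are placeholders):
+ * {{{
+ *   df.write.bucketBy(4, "id").sortBy("id").saveAsTable("bucketed_table")
+ * }}}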
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = {
+ require(numBuckets > 0, "The numBuckets should be > 0.")
+ this.numBuckets = Option(numBuckets)
+ this.bucketColumnNames = Option(colName +: colNames)
+ this
+ }
+
+ /**
+ * Sorts the output in each bucket by the given columns.
+ *
+ * This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark
+ * 2.1.0.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = {
+ this.sortColumnNames = Option(colName +: colNames)
+ this
+ }
+
+ /**
+ * Saves the content of the `DataFrame` at the specified path.
+ *
+ * @since 3.4.0
+ */
+ def save(path: String): Unit = {
+ saveInternal(Some(path))
+ }
+
+ /**
+ * Saves the content of the `DataFrame` as the specified table.
+ *
+ * @since 3.4.0
+ */
+ def save(): Unit = saveInternal(None)
+
+ private def saveInternal(path: Option[String]): Unit = {
+ executeWriteOperation(builder => path.foreach(builder.setPath))
+ }
+
+ private def executeWriteOperation(f: proto.WriteOperation.Builder => Unit): Unit = {
+ val builder = proto.WriteOperation.newBuilder()
+
+ builder.setInput(ds.plan.getRoot)
+
+ // Set path or table
+ f(builder)
+
+ // Cannot both be set
+ require(!(builder.hasPath && builder.hasTable))
+
+ builder.setMode(mode match {
+ case SaveMode.Append => proto.WriteOperation.SaveMode.SAVE_MODE_APPEND
+ case SaveMode.Overwrite => proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE
+ case SaveMode.Ignore => proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE
+ case SaveMode.ErrorIfExists => proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS
+ })
+
+ source.foreach(builder.setSource)
+ sortColumnNames.foreach(names => builder.addAllSortColumnNames(names.asJava))
+ partitioningColumns.foreach(cols => builder.addAllPartitioningColumns(cols.asJava))
+
+ numBuckets.foreach(n => {
+ val bucketBuilder = proto.WriteOperation.BucketBy.newBuilder()
+ bucketBuilder.setNumBuckets(n)
+ bucketColumnNames.foreach(names => bucketBuilder.addAllBucketColumnNames(names.asJava))
+ builder.setBucketBy(bucketBuilder)
+ })
+
+ extraOptions.foreach { case (k, v) =>
+ builder.putOptions(k, v)
+ }
+
+ ds.sparkSession.execute(proto.Command.newBuilder().setWriteOperation(builder).build())
+ }
+
+ /**
+ * Inserts the content of the `DataFrame` to the specified table. It requires that the schema of
+ * the `DataFrame` is the same as the schema of the table.
+ *
+ * @note
+ * Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based
+ * resolution. For example:
+ *
+ * @note
+ * SaveMode.ErrorIfExists and SaveMode.Ignore behave as SaveMode.Append in `insertInto` as
+ * `insertInto` is not a table creating operation.
+ *
+ * {{{
+ * scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1")
+ * scala> Seq((3, 4)).toDF("j", "i").write.insertInto("t1")
+ * scala> Seq((5, 6)).toDF("a", "b").write.insertInto("t1")
+ * scala> sql("select * from t1").show
+ * +---+---+
+ * | i| j|
+ * +---+---+
+ * | 5| 6|
+ * | 3| 4|
+ * | 1| 2|
+ * +---+---+
+ * }}}
+ *
+ * Because it inserts data to an existing table, format or options will be ignored.
+ *
+ * @since 3.4.0
+ */
+ def insertInto(tableName: String): Unit = {
+ executeWriteOperation(builder => {
+ builder.setTable(
+ proto.WriteOperation.SaveTable
+ .newBuilder()
+ .setTableName(tableName)
+ .setSaveMethod(
+ proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO))
+ })
+ }
+
+ /**
+ * Saves the content of the `DataFrame` as the specified table.
+ *
+ * In the case the table already exists, behavior of this function depends on the save mode,
+ * specified by the `mode` function (default to throwing an exception). When `mode` is
+ * `Overwrite`, the schema of the `DataFrame` does not need to be the same as that of the
+ * existing table.
+ *
+ * When `mode` is `Append`, if there is an existing table, we will use the format and options of
+ * the existing table. The column order in the schema of the `DataFrame` doesn't need to be same
+ * as that of the existing table. Unlike `insertInto`, `saveAsTable` will use the column names
+ * to find the correct column positions. For example:
+ *
+ * {{{
+ * scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1")
+ * scala> Seq((3, 4)).toDF("j", "i").write.mode("append").saveAsTable("t1")
+ * scala> sql("select * from t1").show
+ * +---+---+
+ * | i| j|
+ * +---+---+
+ * | 1| 2|
+ * | 4| 3|
+ * +---+---+
+ * }}}
+ *
+ * In this method, save mode is used to determine the behavior if the data source table exists
+ * in Spark catalog. We will always overwrite the underlying data of data source (e.g. a table
+ * in JDBC data source) if the table doesn't exist in Spark catalog, and will always append to
+ * the underlying data of data source if the table already exists.
+ *
+ * When the DataFrame is created from a non-partitioned `HadoopFsRelation` with a single input
+ * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC
+ * and Parquet), the table is persisted in a Hive compatible format, which means other systems
+ * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL
+ * specific format.
+ *
+ * @since 3.4.0
+ */
+ def saveAsTable(tableName: String): Unit = {
+ executeWriteOperation(builder => {
+ builder.setTable(
+ proto.WriteOperation.SaveTable
+ .newBuilder()
+ .setTableName(tableName)
+ .setSaveMethod(
+ proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE))
+ })
+ }
+
+ /**
+ * Saves the content of the `DataFrame` to an external database table via JDBC. In the case the
+ * table already exists in the external database, behavior of this function depends on the save
+ * mode, specified by the `mode` function (default to throwing an exception).
+ *
+ * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
+ * your external database systems.
+ *
+ * You can find the JDBC-specific option and parameter documentation for storing tables via
+ * JDBC in Data Source Option in the version you use.
+ *
+ * @param table
+ * Name of the table in the external database.
+ * @param connectionProperties
+ * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least
+ * a "user" and "password" property should be included. "batchsize" can be used to control the
+ * number of rows per insert. "isolationLevel" can be one of "NONE", "READ_COMMITTED",
+ * "READ_UNCOMMITTED", "REPEATABLE_READ", or "SERIALIZABLE", corresponding to standard
+ * transaction isolation levels defined by JDBC's Connection object, with default of
+ * "READ_UNCOMMITTED".
+ * @since 3.4.0
+ */
+ def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
+ // connectionProperties should override settings in extraOptions.
+ this.extraOptions ++= connectionProperties.asScala
+ // explicit url and dbtable should override all
+ this.extraOptions ++= Seq("url" -> url, "dbtable" -> table)
+ format("jdbc").save()
+ }
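+
+ // Illustrative usage (editor's sketch; the URL, table name and credentials are hypothetical):
+ //   val props = new java.util.Properties()
+ //   props.setProperty("user", "spark")
+ //   props.setProperty("password", "secret")
+ //   df.write.mode("append").jdbc("jdbc:postgresql://host:5432/db", "public.events", props)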
+
+ /**
+ * Saves the content of the `DataFrame` in JSON format (JSON Lines text format or
+ * newline-delimited JSON) at the specified path. This is equivalent to:
+ * {{{
+ * format("json").save(path)
+ * }}}
+ *
+ * You can find the JSON-specific options for writing JSON files in
+ * Data Source Option in the version you use.
+ *
+ * @since 3.4.0
+ */
+ def json(path: String): Unit = {
+ format("json").save(path)
+ }
+
+ /**
+ * Saves the content of the `DataFrame` in Parquet format at the specified path. This is
+ * equivalent to:
+ * {{{
+ * format("parquet").save(path)
+ * }}}
+ *
+ * Parquet-specific option(s) for writing Parquet files can be found in Data
+ * Source Option in the version you use.
+ *
+ * @since 3.4.0
+ */
+ def parquet(path: String): Unit = {
+ format("parquet").save(path)
+ }
+
+ /**
+ * Saves the content of the `DataFrame` in ORC format at the specified path. This is equivalent
+ * to:
+ * {{{
+ * format("orc").save(path)
+ * }}}
+ *
+ * ORC-specific option(s) for writing ORC files can be found in Data
+ * Source Option in the version you use.
+ *
+ * @since 3.4.0
+ */
+ def orc(path: String): Unit = {
+ format("orc").save(path)
+ }
+
+ /**
+ * Saves the content of the `DataFrame` in a text file at the specified path. The DataFrame must
+ * have only one column that is of string type. Each row becomes a new line in the output file.
+ * For example:
+ * {{{
+ * // Scala:
+ * df.write.text("/path/to/output")
+ *
+ * // Java:
+ * df.write().text("/path/to/output")
+ * }}}
+ * The text files will be encoded as UTF-8.
+ *
+ * You can find the text-specific options for writing text files in
+ * Data Source Option in the version you use.
+ *
+ * @since 3.4.0
+ */
+ def text(path: String): Unit = {
+ format("text").save(path)
+ }
+
+ /**
+ * Saves the content of the `DataFrame` in CSV format at the specified path. This is equivalent
+ * to:
+ * {{{
+ * format("csv").save(path)
+ * }}}
+ *
+ * You can find the CSV-specific options for writing CSV files in
+ * Data Source Option in the version you use.
+ *
+ * @since 3.4.0
+ */
+ def csv(path: String): Unit = {
+ format("csv").save(path)
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////////////
+ // Builder pattern config options
+ ///////////////////////////////////////////////////////////////////////////////////////
+
+ private var source: Option[String] = None
+
+ private var mode: SaveMode = SaveMode.ErrorIfExists
+
+ private var extraOptions = CaseInsensitiveMap[String](Map.empty)
+
+ private var partitioningColumns: Option[Seq[String]] = None
+
+ private var bucketColumnNames: Option[Seq[String]] = None
+
+ private var numBuckets: Option[Int] = None
+
+ private var sortColumnNames: Option[Seq[String]] = None
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
new file mode 100644
index 0000000000000..b698e1dfaa1c9
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.connect.proto
+
+/**
+ * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2
+ * API.
+ *
+ * @since 3.4.0
+ */
+@Experimental
+final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
+ extends CreateTableWriter[T] {
+
+ private var provider: Option[String] = None
+
+ private val options = new mutable.HashMap[String, String]()
+
+ private val properties = new mutable.HashMap[String, String]()
+
+ private var partitioning: Option[Seq[proto.Expression]] = None
+
+ private var overwriteCondition: Option[proto.Expression] = None
+
+ override def using(provider: String): CreateTableWriter[T] = {
+ this.provider = Some(provider)
+ this
+ }
+
+ override def option(key: String, value: String): DataFrameWriterV2[T] = {
+ this.options.put(key, value)
+ this
+ }
+
+ override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = {
+ options.foreach { case (key, value) =>
+ this.options.put(key, value)
+ }
+ this
+ }
+
+ override def options(options: java.util.Map[String, String]): DataFrameWriterV2[T] = {
+ this.options(options.asScala)
+ this
+ }
+
+ override def tableProperty(property: String, value: String): CreateTableWriter[T] = {
+ this.properties.put(property, value)
+ this
+ }
+
+ @scala.annotation.varargs
+ override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = {
+ val asTransforms = (column +: columns).map(_.expr)
+ this.partitioning = Some(asTransforms)
+ this
+ }
+
+ override def create(): Unit = {
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE)
+ }
+
+ override def replace(): Unit = {
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_REPLACE)
+ }
+
+ override def createOrReplace(): Unit = {
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE)
+ }
+
+ /**
+ * Append the contents of the data frame to the output table.
+ *
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
+ * validated to ensure it is compatible with the existing table.
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+ * If the table does not exist
+ */
+ def append(): Unit = {
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_APPEND)
+ }
+
+ /**
+ * Overwrite rows matching the given filter condition with the contents of the data frame in the
+ * output table.
+ *
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
+ * validated to ensure it is compatible with the existing table.
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+ * If the table does not exist
+ */
+ def overwrite(condition: Column): Unit = {
+ overwriteCondition = Some(condition.expr)
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE)
+ }
+
+ /**
+ * Overwrite all partitions for which the data frame contains at least one row with the contents
+ * of the data frame in the output table.
+ *
+ * This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces
+ * partitions dynamically depending on the contents of the data frame.
+ *
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
+ * validated to ensure it is compatible with the existing table.
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+ * If the table does not exist
+ */
+ def overwritePartitions(): Unit = {
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS)
+ }
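+
+ // Illustrative usage (editor's sketch; it assumes the usual `Dataset.writeTo` entry point and a
+ // hypothetical table name):
+ //   df.writeTo("catalog.db.events")
+ //     .using("parquet")
+ //     .partitionedBy(df.col("day"))
+ //     .createOrReplace()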
+
+ private def executeWriteOperation(mode: proto.WriteOperationV2.Mode): Unit = {
+ val builder = proto.WriteOperationV2.newBuilder()
+
+ builder.setInput(ds.plan.getRoot)
+ builder.setTableName(table)
+ provider.foreach(builder.setProvider)
+
+ partitioning.foreach(columns => builder.addAllPartitioningColumns(columns.asJava))
+
+ options.foreach { case (k, v) =>
+ builder.putOptions(k, v)
+ }
+ properties.foreach { case (k, v) =>
+ builder.putTableProperties(k, v)
+ }
+
+ builder.setMode(mode)
+
+ overwriteCondition.foreach(builder.setOverwriteCondition)
+
+ ds.sparkSession.execute(proto.Command.newBuilder().setWriteOperationV2(builder).build())
+ }
+}
+
+/**
+ * Configuration methods common to create/replace operations and insert/overwrite operations.
+ * @tparam R
+ * builder type to return
+ * @since 3.4.0
+ */
+trait WriteConfigMethods[R] {
+
+ /**
+ * Add a write option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: String): R
+
+ /**
+ * Add a boolean output option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Boolean): R = option(key, value.toString)
+
+ /**
+ * Add a long output option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Long): R = option(key, value.toString)
+
+ /**
+ * Add a double output option.
+ *
+ * @since 3.4.0
+ */
+ def option(key: String, value: Double): R = option(key, value.toString)
+
+ /**
+ * Add write options from a Scala Map.
+ *
+ * @since 3.4.0
+ */
+ def options(options: scala.collection.Map[String, String]): R
+
+ /**
+ * Add write options from a Java Map.
+ *
+ * @since 3.4.0
+ */
+ def options(options: java.util.Map[String, String]): R
+}
+
+/**
+ * Trait to restrict calls to create and replace operations.
+ *
+ * @since 3.4.0
+ */
+trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
+
+ /**
+ * Create a new table from the contents of the data frame.
+ *
+ * The new table's schema, partition layout, properties, and other configuration will be based
+ * on the configuration set on this writer.
+ *
+ * If the output table exists, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException]].
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+ * If the table already exists
+ */
+ def create(): Unit
+
+ /**
+ * Replace an existing table with the contents of the data frame.
+ *
+ * The existing table's schema, partition layout, properties, and other configuration will be
+ * replaced with the contents of the data frame and the configuration set on this writer.
+ *
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException]].
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
+ * If the table does not exist
+ */
+ def replace(): Unit
+
+ /**
+ * Create a new table or replace an existing table with the contents of the data frame.
+ *
+ * The output table's schema, partition layout, properties, and other configuration will be
+ * based on the contents of the data frame and the configuration set on this writer. If the
+ * table exists, its configuration and data will be replaced.
+ */
+ def createOrReplace(): Unit
+
+ /**
+ * Partition the output table created by `create`, `createOrReplace`, or `replace` using the
+ * given columns or transforms.
+ *
+ * When specified, the table data will be stored by these values for efficient reads.
+ *
+ * For example, when a table is partitioned by day, it may be stored in a directory layout like:
+ *
+ *   - `table/day=2019-06-01/`
+ *   - `table/day=2019-06-02/`
+ *
+ * Partitioning is one of the most widely used techniques to optimize physical data layout. It
+ * provides a coarse-grained index for skipping unnecessary data reads when queries have
+ * predicates on the partitioned columns. In order for partitioning to work well, the number of
+ * distinct values in each column should typically be less than tens of thousands.
+ *
+ * @since 3.4.0
+ */
+ def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T]
+
+ /**
+ * Specifies a provider for the underlying output data source. Spark's default catalog supports
+ * "parquet", "json", etc.
+ *
+ * @since 3.4.0
+ */
+ def using(provider: String): CreateTableWriter[T]
+
+ /**
+ * Add a table property.
+ */
+ def tableProperty(property: String, value: String): CreateTableWriter[T]
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
new file mode 100644
index 0000000000000..ca90afa14cf3f
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -0,0 +1,2870 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import java.util.{Collections, Locale}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.util.control.NonFatal
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveLongEncoder, StringEncoder, UnboundRowEncoder}
+import org.apache.spark.sql.catalyst.expressions.RowOrdering
+import org.apache.spark.sql.connect.client.SparkResult
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
+import org.apache.spark.sql.functions.{struct, to_json}
+import org.apache.spark.sql.types.{Metadata, StructType}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+/**
+ * A Dataset is a strongly typed collection of domain-specific objects that can be transformed in
+ * parallel using functional or relational operations. Each Dataset also has an untyped view
+ * called a `DataFrame`, which is a Dataset of [[Row]].
+ *
+ * Operations available on Datasets are divided into transformations and actions. Transformations
+ * are the ones that produce new Datasets, and actions are the ones that trigger computation and
+ * return results. Example transformations include map, filter, select, and aggregate (`groupBy`).
+ * Example actions include count, show, and writing data out to file systems.
+ *
+ * Datasets are "lazy", i.e. computations are only triggered when an action is invoked.
+ * Internally, a Dataset represents a logical plan that describes the computation required to
+ * produce the data. When an action is invoked, Spark's query optimizer optimizes the logical plan
+ * and generates a physical plan for efficient execution in a parallel and distributed manner. To
+ * explore the logical plan as well as optimized physical plan, use the `explain` function.
+ *
+ * To efficiently support domain-specific objects, an [[Encoder]] is required. The encoder maps
+ * the domain specific type `T` to Spark's internal type system. For example, given a class
+ * `Person` with two fields, `name` (string) and `age` (int), an encoder is used to tell Spark to
+ * generate code at runtime to serialize the `Person` object into a binary structure. This binary
+ * structure often has a much lower memory footprint and is optimized for efficiency in data
+ * processing (e.g. in a columnar format). To understand the internal binary representation for
+ * data, use the `schema` function.
+ *
+ * There are typically two ways to create a Dataset. The most common way is by pointing Spark to
+ * some files on storage systems, using the `read` function available on a `SparkSession`.
+ * {{{
+ * val people = spark.read.parquet("...").as[Person] // Scala
+ * Dataset people = spark.read().parquet("...").as(Encoders.bean(Person.class)); // Java
+ * }}}
+ *
+ * Datasets can also be created through transformations available on existing Datasets. For
+ * example, the following creates a new Dataset by applying a filter on the existing one:
+ * {{{
+ * val names = people.map(_.name) // in Scala; names is a Dataset[String]
+ * Dataset names = people.map((Person p) -> p.name, Encoders.STRING));
+ * }}}
+ *
+ * Dataset operations can also be untyped, through various domain-specific-language (DSL)
+ * functions defined in: Dataset (this class), [[Column]], and [[functions]]. These operations are
+ * very similar to the operations available in the data frame abstraction in R or Python.
+ *
+ * To select a column from the Dataset, use `apply` method in Scala and `col` in Java.
+ * {{{
+ * val ageCol = people("age") // in Scala
+ * Column ageCol = people.col("age"); // in Java
+ * }}}
+ *
+ * Note that the [[Column]] type can also be manipulated through its various functions.
+ * {{{
+ * // The following creates a new column that increases everybody's age by 10.
+ * people("age") + 10 // in Scala
+ * people.col("age").plus(10); // in Java
+ * }}}
+ *
+ * A more concrete example in Scala:
+ * {{{
+ * // To create Dataset[Row] using SparkSession
+ * val people = spark.read.parquet("...")
+ * val department = spark.read.parquet("...")
+ *
+ * people.filter("age > 30")
+ * .join(department, people("deptId") === department("id"))
+ * .groupBy(department("name"), people("gender"))
+ * .agg(avg(people("salary")), max(people("age")))
+ * }}}
+ *
+ * and in Java:
+ * {{{
+ * // To create Dataset using SparkSession
+ * Dataset people = spark.read().parquet("...");
+ * Dataset department = spark.read().parquet("...");
+ *
+ * people.filter(people.col("age").gt(30))
+ * .join(department, people.col("deptId").equalTo(department.col("id")))
+ * .groupBy(department.col("name"), people.col("gender"))
+ * .agg(avg(people.col("salary")), max(people.col("age")));
+ * }}}
+ *
+ * @groupname basic Basic Dataset functions
+ * @groupname action Actions
+ * @groupname untypedrel Untyped transformations
+ * @groupname typedrel Typed transformations
+ *
+ * @since 3.4.0
+ */
+class Dataset[T] private[sql] (
+ val sparkSession: SparkSession,
+ @DeveloperApi val plan: proto.Plan,
+ val encoder: AgnosticEncoder[T])
+ extends Serializable {
+ // Make sure we don't forget to set plan id.
+ assert(plan.getRoot.getCommon.hasPlanId)
+
+ override def toString: String = {
+ try {
+ val builder = new mutable.StringBuilder
+ val fields = schema.take(2).map { f =>
+ s"${f.name}: ${f.dataType.simpleString(2)}"
+ }
+ builder.append("[")
+ builder.append(fields.mkString(", "))
+ if (schema.length > 2) {
+ if (schema.length - fields.size == 1) {
+ builder.append(" ... 1 more field")
+ } else {
+ builder.append(" ... " + (schema.length - 2) + " more fields")
+ }
+ }
+ builder.append("]").toString()
+ } catch {
+ case NonFatal(e) =>
+ s"Invalid Dataframe; ${e.getMessage}"
+ }
+ }
+
+ /**
+ * Converts this strongly typed collection of data to generic Dataframe. In contrast to the
+ * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]]
+ * objects that allow fields to be accessed by ordinal or name.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def toDF(): DataFrame = new Dataset(sparkSession, plan, UnboundRowEncoder)
+
+ /**
+ * Returns a new Dataset where each record has been mapped on to the specified type. The method
+ * used to map columns depends on the type of `U`:
+ *   - When `U` is a class, fields for the class will be mapped to columns of the same name
+ *     (case sensitivity is determined by `spark.sql.caseSensitive`).
+ *   - When `U` is a tuple, the columns will be mapped by ordinal (i.e. the first column will be
+ *     assigned to `_1`).
+ *   - When `U` is a primitive type (i.e. String, Int, etc), then the first column of the
+ *     `DataFrame` will be used.
+ *
+ * If the schema of the Dataset does not match the desired `U` type, you can use `select` along
+ * with `alias` or `as` to rearrange or rename as required.
+ *
+ * Note that `as[]` only changes the view of the data that is passed into typed operations, such
+ * as `map()`, and does not eagerly project away any columns that are not present in the
+ * specified class.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def as[U: Encoder]: Dataset[U] = {
+ val encoder = implicitly[Encoder[U]].asInstanceOf[AgnosticEncoder[U]]
+ // We should add some validation/coercion here. We cannot use `to`
+ // because that does not work with positional arguments.
+ new Dataset[U](sparkSession, plan, encoder)
+ }
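+
+ // Illustrative usage (editor's sketch; `Person` is a hypothetical case class and the implicit
+ // Encoder for it is assumed to be in scope, e.g. via the session's implicits):
+ //   case class Person(name: String, age: Long)
+ //   val people = spark.read.parquet("...").as[Person]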
+
+ /**
+ * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed.
+ * This can be quite convenient in conversion from an RDD of tuples into a `DataFrame` with
+ * meaningful names. For example:
+ * {{{
+ * val rdd: RDD[(Int, String)] = ...
+ * rdd.toDF() // this implicit conversion creates a DataFrame with column name `_1` and `_2`
+ * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name"
+ * }}}
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def toDF(colNames: String*): DataFrame = sparkSession.newDataFrame { builder =>
+ builder.getToDfBuilder
+ .setInput(plan.getRoot)
+ .addAllColumnNames(colNames.asJava)
+ }
+
+ /**
+ * Returns a new DataFrame where each row is reconciled to match the specified schema. Spark
+ * will:
+ *   - Reorder columns and/or inner fields by name to match the specified schema.
+ *   - Project away columns and/or inner fields that are not needed by the specified schema.
+ *     Missing columns and/or inner fields (present in the specified schema but not in the input
+ *     DataFrame) lead to failures.
+ *   - Cast the columns and/or inner fields to match the data types in the specified schema, if
+ *     the types are compatible, e.g., numeric to numeric (error if overflows), but not string to
+ *     int.
+ *   - Carry over the metadata from the specified schema, while the columns and/or inner fields
+ *     still keep their own metadata if not overwritten by the specified schema.
+ *   - Fail if the nullability is not compatible. For example, the column and/or inner field is
+ *     nullable but the specified schema requires them to be not nullable.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def to(schema: StructType): DataFrame = sparkSession.newDataFrame { builder =>
+ builder.getToSchemaBuilder
+ .setInput(plan.getRoot)
+ .setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
+ }
+
+ /**
+ * Returns the schema of this Dataset.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def schema: StructType = {
+ if (encoder == UnboundRowEncoder) {
+ DataTypeProtoConverter
+ .toCatalystType(
+ sparkSession
+ .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA)
+ .getSchema
+ .getSchema)
+ .asInstanceOf[StructType]
+ } else {
+ encoder.schema
+ }
+ }
+
+ /**
+ * Prints the schema to the console in a nice tree format.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def printSchema(): Unit = printSchema(Int.MaxValue)
+
+ // scalastyle:off println
+ /**
+ * Prints the schema up to the given level to the console in a nice tree format.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def printSchema(level: Int): Unit = println(schema.treeString(level))
+ // scalastyle:on println
+
+ /**
+ * Prints the plans (logical and physical) with a format specified by a given explain mode.
+ *
+ * @param mode
+ * specifies the expected output format of plans.
+ *   - `simple`: Print only a physical plan.
+ *   - `extended`: Print both logical and physical plans.
+ *   - `codegen`: Print a physical plan and generated codes if they are available.
+ *   - `cost`: Print a logical plan and statistics if they are available.
+ *   - `formatted`: Split explain output into two sections: a physical plan outline and node
+ *     details.
+ * @group basic
+ * @since 3.4.0
+ */
+ def explain(mode: String): Unit = {
+ val protoMode = mode.trim.toLowerCase(Locale.ROOT) match {
+ case "simple" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE
+ case "extended" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_EXTENDED
+ case "codegen" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_CODEGEN
+ case "cost" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_COST
+ case "formatted" => proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_FORMATTED
+ case _ => throw new IllegalArgumentException("Unsupported explain mode: " + mode)
+ }
+ explain(protoMode)
+ }
+
+ /**
+ * Prints the plans (logical and physical) to the console for debugging purposes.
+ *
+ * @param extended
+ * default `false`. If `false`, prints only the physical plan.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def explain(extended: Boolean): Unit = {
+ val mode = if (extended) {
+ proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_EXTENDED
+ } else {
+ proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE
+ }
+ explain(mode)
+ }
+
+ /**
+ * Prints the physical plan to the console for debugging purposes.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def explain(): Unit = explain(proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE)
+
+ private def explain(mode: proto.AnalyzePlanRequest.Explain.ExplainMode): Unit = {
+ // scalastyle:off println
+ println(
+ sparkSession
+ .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN, Some(mode))
+ .getExplain
+ .getExplainString)
+ // scalastyle:on println
+ }
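+
+ // Editor's note (illustrative): the public `explain` overloads above all funnel into this
+ // method, e.g.
+ //   df.explain()            // physical plan only (simple mode)
+ //   df.explain(true)        // logical and physical plans (extended mode)
+ //   df.explain("formatted") // plan outline plus node details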
+
+ /**
+ * Returns all column names and their data types as an array.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def dtypes: Array[(String, String)] = schema.fields.map { field =>
+ (field.name, field.dataType.toString)
+ }
+
+ /**
+ * Returns all column names as an array.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def columns: Array[String] = schema.fields.map(_.name)
+
+ /**
+ * Returns true if the `collect` and `take` methods can be run locally (without any Spark
+ * executors).
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def isLocal: Boolean = sparkSession
+ .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL)
+ .getIsLocal
+ .getIsLocal
+
+ /**
+ * Returns true if the `Dataset` is empty.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def isEmpty: Boolean = select().limit(1).withResult { result =>
+ result.length == 0
+ }
+
+ /**
+ * Returns true if this Dataset contains one or more sources that continuously return data as it
+ * arrives. A Dataset that reads data from a streaming source must be executed as a
+ * `StreamingQuery` using the `start()` method in `DataStreamWriter`.
+ *
+ * @group streaming
+ * @since 3.4.0
+ */
+ def isStreaming: Boolean = sparkSession
+ .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING)
+ .getIsStreaming
+ .getIsStreaming
+
+ /**
+ * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated,
+ * and all cells will be aligned right. For example:
+ * {{{
+ * year month AVG('Adj Close) MAX('Adj Close)
+ * 1980 12 0.503218 0.595103
+ * 1981 01 0.523289 0.570307
+ * 1982 02 0.436504 0.475256
+ * 1983 03 0.410516 0.442194
+ * 1984 04 0.450090 0.483521
+ * }}}
+ *
+ * @param numRows
+ * Number of rows to show
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ def show(numRows: Int): Unit = show(numRows, truncate = true)
+
+ /**
+ * Displays the top 20 rows of Dataset in a tabular form. Strings more than 20 characters will
+ * be truncated, and all cells will be aligned right.
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ def show(): Unit = show(20)
+
+ /**
+ * Displays the top 20 rows of Dataset in a tabular form.
+ *
+ * @param truncate
+ * Whether to truncate long strings. If true, strings more than 20 characters will be truncated
+ * and all cells will be aligned right
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ def show(truncate: Boolean): Unit = show(20, truncate)
+
+ /**
+ * Displays the Dataset in a tabular form. For example:
+ * {{{
+ * year month AVG('Adj Close) MAX('Adj Close)
+ * 1980 12 0.503218 0.595103
+ * 1981 01 0.523289 0.570307
+ * 1982 02 0.436504 0.475256
+ * 1983 03 0.410516 0.442194
+ * 1984 04 0.450090 0.483521
+ * }}}
+ * @param numRows
+ * Number of rows to show
+ * @param truncate
+ * Whether to truncate long strings. If true, strings more than 20 characters will be truncated
+ * and all cells will be aligned right
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ // scalastyle:off println
+ def show(numRows: Int, truncate: Boolean): Unit = {
+ val truncateValue = if (truncate) 20 else 0
+ show(numRows, truncateValue, vertical = false)
+ }
+
+ /**
+ * Displays the Dataset in a tabular form. For example:
+ * {{{
+ * year month AVG('Adj Close) MAX('Adj Close)
+ * 1980 12 0.503218 0.595103
+ * 1981 01 0.523289 0.570307
+ * 1982 02 0.436504 0.475256
+ * 1983 03 0.410516 0.442194
+ * 1984 04 0.450090 0.483521
+ * }}}
+ *
+ * @param numRows
+ * Number of rows to show
+ * @param truncate
+ * If set to more than 0, truncates strings to `truncate` characters and all cells will be
+ * aligned right.
+ * @group action
+ * @since 3.4.0
+ */
+ def show(numRows: Int, truncate: Int): Unit = show(numRows, truncate, vertical = false)
+
+ /**
+ * Displays the Dataset in a tabular form. For example:
+ * {{{
+ * year month AVG('Adj Close) MAX('Adj Close)
+ * 1980 12 0.503218 0.595103
+ * 1981 01 0.523289 0.570307
+ * 1982 02 0.436504 0.475256
+ * 1983 03 0.410516 0.442194
+ * 1984 04 0.450090 0.483521
+ * }}}
+ *
+ * If `vertical` is enabled, this command prints output rows vertically (one line per column
+ * value):
+ *
+ * {{{
+ * -RECORD 0-------------------
+ * year | 1980
+ * month | 12
+ * AVG('Adj Close) | 0.503218
+ * MAX('Adj Close) | 0.595103
+ * -RECORD 1-------------------
+ * year | 1981
+ * month | 01
+ * AVG('Adj Close) | 0.523289
+ * MAX('Adj Close) | 0.570307
+ * -RECORD 2-------------------
+ * year | 1982
+ * month | 02
+ * AVG('Adj Close) | 0.436504
+ * MAX('Adj Close) | 0.475256
+ * -RECORD 3-------------------
+ * year | 1983
+ * month | 03
+ * AVG('Adj Close) | 0.410516
+ * MAX('Adj Close) | 0.442194
+ * -RECORD 4-------------------
+ * year | 1984
+ * month | 04
+ * AVG('Adj Close) | 0.450090
+ * MAX('Adj Close) | 0.483521
+ * }}}
+ *
+ * @param numRows
+ * Number of rows to show
+ * @param truncate
+ * If set to more than 0, truncates strings to `truncate` characters and all cells will be
+ * aligned right.
+ * @param vertical
+ * If set to true, prints output rows vertically (one line per column value).
+ * @group action
+ * @since 3.4.0
+ */
+ def show(numRows: Int, truncate: Int, vertical: Boolean): Unit = {
+ val df = sparkSession.newDataset(StringEncoder) { builder =>
+ builder.getShowStringBuilder
+ .setInput(plan.getRoot)
+ .setNumRows(numRows)
+ .setTruncate(truncate)
+ .setVertical(vertical)
+ }
+ df.withResult { result =>
+ assert(result.length == 1)
+ assert(result.schema.size == 1)
+ // scalastyle:off println
+ println(result.toArray.head)
+ // scalastyle:on println
+ }
+ }
+
+ /**
+ * Returns a [[DataFrameNaFunctions]] for working with missing data.
+ * {{{
+ * // Dropping rows containing any null values.
+ * ds.na.drop()
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def na: DataFrameNaFunctions = new DataFrameNaFunctions(sparkSession, plan.getRoot)
+
+ /**
+ * Returns a [[DataFrameStatFunctions]] for working with statistic functions.
+ * {{{
+ * // Finding frequent items in column with name 'a'.
+ * ds.stat.freqItems(Seq("a"))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def stat: DataFrameStatFunctions = new DataFrameStatFunctions(sparkSession, plan.getRoot)
+
+ private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ val joinBuilder = builder.getJoinBuilder
+ joinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
+ f(joinBuilder)
+ }
+ }
+
+ private def toJoinType(name: String): proto.Join.JoinType = {
+ name.trim.toLowerCase(Locale.ROOT) match {
+ case "inner" =>
+ proto.Join.JoinType.JOIN_TYPE_INNER
+ case "cross" =>
+ proto.Join.JoinType.JOIN_TYPE_CROSS
+ case "outer" | "full" | "fullouter" | "full_outer" =>
+ proto.Join.JoinType.JOIN_TYPE_FULL_OUTER
+ case "left" | "leftouter" | "left_outer" =>
+ proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER
+ case "right" | "rightouter" | "right_outer" =>
+ proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER
+ case "semi" | "leftsemi" | "left_semi" =>
+ proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI
+ case "anti" | "leftanti" | "left_anti" =>
+ proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI
+ case _ =>
+ throw new IllegalArgumentException(s"Unsupported join type `$name`.")
+ }
+ }
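+
+ // Editor's note: the accepted join-type strings are matched case-insensitively, so e.g.
+ // "left_outer", "leftouter" and "LEFT" all resolve to JOIN_TYPE_LEFT_OUTER.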
+
+ /**
+ * Join with another `DataFrame`.
+ *
+ * Behaves as an INNER JOIN and requires a subsequent join predicate.
+ *
+ * @param right
+ * Right side of the join operation.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def join(right: Dataset[_]): DataFrame = buildJoin(right) { builder =>
+ builder.setJoinType(proto.Join.JoinType.JOIN_TYPE_INNER)
+ }
+
+ /**
+ * Inner equi-join with another `DataFrame` using the given column.
+ *
+ * Different from other join functions, the join column will only appear once in the output,
+ * i.e. similar to SQL's `JOIN USING` syntax.
+ *
+ * {{{
+ * // Joining df1 and df2 using the column "user_id"
+ * df1.join(df2, "user_id")
+ * }}}
+ *
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumn
+ * Name of the column to join on. This column must exist on both sides.
+ *
+ * @note
+ * If you perform a self-join using this function without aliasing the input `DataFrame`s, you
+ * will NOT be able to reference any columns after the join, since there is no way to
+ * disambiguate which side of the join you would like to reference.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def join(right: Dataset[_], usingColumn: String): DataFrame = {
+ join(right, Seq(usingColumn))
+ }
+
+ /**
+ * (Java-specific) Inner equi-join with another `DataFrame` using the given columns. See the
+ * Scala-specific overload for more details.
+ *
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumns
+ * Names of the columns to join on. These columns must exist on both sides.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def join(right: Dataset[_], usingColumns: Array[String]): DataFrame = {
+ join(right, usingColumns.toSeq)
+ }
+
+ /**
+ * (Scala-specific) Inner equi-join with another `DataFrame` using the given columns.
+ *
+ * Different from other join functions, the join columns will only appear once in the output,
+ * i.e. similar to SQL's `JOIN USING` syntax.
+ *
+ * {{{
+ * // Joining df1 and df2 using the columns "user_id" and "user_name"
+ * df1.join(df2, Seq("user_id", "user_name"))
+ * }}}
+ *
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumns
+ * Names of the columns to join on. These columns must exist on both sides.
+ *
+ * @note
+ * If you perform a self-join using this function without aliasing the input `DataFrame`s, you
+ * will NOT be able to reference any columns after the join, since there is no way to
+ * disambiguate which side of the join you would like to reference.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame = {
+ join(right, usingColumns, "inner")
+ }
+
+ /**
+ * Equi-join with another `DataFrame` using the given column. A cross join with a predicate is
+ * specified as an inner join. If you would explicitly like to perform a cross join use the
+ * `crossJoin` method.
+ *
+ * Different from other join functions, the join column will only appear once in the output,
+ * i.e. similar to SQL's `JOIN USING` syntax.
+ *
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumn
+ * Name of the column to join on. This column must exist on both sides.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+ * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`,
+ * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`,
+ * `left_anti`.
+ *
+ * @note
+ * If you perform a self-join using this function without aliasing the input `DataFrame`s, you
+ * will NOT be able to reference any columns after the join, since there is no way to
+ * disambiguate which side of the join you would like to reference.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def join(right: Dataset[_], usingColumn: String, joinType: String): DataFrame = {
+ join(right, Seq(usingColumn), joinType)
+ }
+
+ /**
+ * (Java-specific) Equi-join with another `DataFrame` using the given columns. See the
+ * Scala-specific overload for more details.
+ *
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumns
+ * Names of the columns to join on. These columns must exist on both sides.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+ * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`,
+ * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`,
+ * `left_anti`.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def join(right: Dataset[_], usingColumns: Array[String], joinType: String): DataFrame = {
+ join(right, usingColumns.toSeq, joinType)
+ }
+
+ /**
+ * (Scala-specific) Equi-join with another `DataFrame` using the given columns. A cross join
+ * with a predicate is specified as an inner join. If you would explicitly like to perform a
+ * cross join use the `crossJoin` method.
+ *
+ * Different from other join functions, the join columns will only appear once in the output,
+ * i.e. similar to SQL's `JOIN USING` syntax.
+ *
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumns
+ * Names of the columns to join on. These columns must exist on both sides.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+ * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`,
+ * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`,
+ * `left_anti`.
+ *
+ * @note
+ * If you perform a self-join using this function without aliasing the input `DataFrame`s, you
+ * will NOT be able to reference any columns after the join, since there is no way to
+ * disambiguate which side of the join you would like to reference.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = {
+ buildJoin(right) { builder =>
+ builder
+ .setJoinType(toJoinType(joinType))
+ .addAllUsingColumns(usingColumns.asJava)
+ }
+ }
+
+ /**
+ * Inner join with another `DataFrame`, using the given join expression.
+ *
+ * {{{
+ * // The following two are equivalent:
+ * df1.join(df2, $"df1Key" === $"df2Key")
+ * df1.join(df2).where($"df1Key" === $"df2Key")
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def join(right: Dataset[_], joinExprs: Column): DataFrame = join(right, joinExprs, "inner")
+
+ /**
+ * Join with another `DataFrame`, using the given join expression. The following performs a full
+ * outer join between `df1` and `df2`.
+ *
+ * {{{
+ * // Scala:
+ * import org.apache.spark.sql.functions._
+ * df1.join(df2, $"df1Key" === $"df2Key", "outer")
+ *
+ * // Java:
+ * import static org.apache.spark.sql.functions.*;
+ * df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer");
+ * }}}
+ *
+ * @param right
+ * Right side of the join.
+ * @param joinExprs
+ * Join expression.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+ * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`,
+ * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`,
+ * `left_anti`.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = {
+ buildJoin(right) { builder =>
+ builder
+ .setJoinType(toJoinType(joinType))
+ .setJoinCondition(joinExprs.expr)
+ }
+ }
+
+ /**
+ * Explicit cartesian join with another `DataFrame`.
+ *
+ * @param right
+ * Right side of the join operation.
+ *
+ * @note
+ * Cartesian joins are very expensive without an extra filter that can be pushed down.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def crossJoin(right: Dataset[_]): DataFrame = buildJoin(right) { builder =>
+ builder.setJoinType(proto.Join.JoinType.JOIN_TYPE_CROSS)
+ }
+
+ private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
+ sparkSession.newDataset(encoder) { builder =>
+ builder.getSortBuilder
+ .setInput(plan.getRoot)
+ .setIsGlobal(global)
+ .addAllOrder(sortExprs.map(_.sortOrder).asJava)
+ }
+ }
+
+ /**
+ * Returns a new Dataset with each partition sorted by the given expressions.
+ *
+ * This is the same operation as "SORT BY" in SQL (Hive QL).
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = {
+ sortWithinPartitions((sortCol +: sortCols).map(Column(_)): _*)
+ }
+
+ /**
+ * Returns a new Dataset with each partition sorted by the given expressions.
+ *
+ * This is the same operation as "SORT BY" in SQL (Hive QL).
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def sortWithinPartitions(sortExprs: Column*): Dataset[T] = {
+ buildSort(global = false, sortExprs)
+ }
+
+ /**
+ * Returns a new Dataset sorted by the specified column, all in ascending order.
+ * {{{
+ * // The following 3 are equivalent
+ * ds.sort("sortcol")
+ * ds.sort($"sortcol")
+ * ds.sort($"sortcol".asc)
+ * }}}
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def sort(sortCol: String, sortCols: String*): Dataset[T] = {
+ sort((sortCol +: sortCols).map(Column(_)): _*)
+ }
+
+ /**
+ * Returns a new Dataset sorted by the given expressions. For example:
+ * {{{
+ * ds.sort($"col1", $"col2".desc)
+ * }}}
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def sort(sortExprs: Column*): Dataset[T] = {
+ buildSort(global = true, sortExprs)
+ }
+
+ /**
+ * Returns a new Dataset sorted by the given expressions. This is an alias of the `sort`
+ * function.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols: _*)
+
+ /**
+ * Returns a new Dataset sorted by the given expressions. This is an alias of the `sort`
+ * function.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs: _*)
+
+ /**
+ * Selects column based on the column name and returns it as a [[Column]].
+ *
+ * @note
+ * The column name can also reference a nested column like `a.b`.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def apply(colName: String): Column = col(colName)
+
+ /**
+ * Specifies some hint on the current Dataset. As an example, the following code specifies that
+ * one of the plans can be broadcast:
+ *
+ * {{{
+ * df1.join(df2.hint("broadcast"))
+ * }}}
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset(encoder) {
+ builder =>
+ builder.getHintBuilder
+ .setInput(plan.getRoot)
+ .setName(name)
+ .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
+ }
+
+ private def getPlanId: Option[Long] =
+ if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) {
+ Option(plan.getRoot.getCommon.getPlanId)
+ } else {
+ None
+ }
+
+ /**
+ * Selects column based on the column name and returns it as a [[Column]].
+ *
+ * @note
+ * The column name can also reference a nested column like `a.b`.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def col(colName: String): Column = {
+ Column.apply(colName, getPlanId)
+ }
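+
+ // Illustrative usage (editor's sketch):
+ //   val ageCol  = df.col("age")
+ //   val cityCol = df.col("address.city") // nested field reference, as noted above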
+
+ /**
+ * Selects column based on the column name specified as a regex and returns it as [[Column]].
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def colRegex(colName: String): Column = {
+ Column { builder =>
+ val unresolvedRegexBuilder = builder.getUnresolvedRegexBuilder.setColName(colName)
+ getPlanId.foreach(unresolvedRegexBuilder.setPlanId)
+ }
+ }
+
+ /**
+ * Returns a new Dataset with an alias set.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def as(alias: String): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+ builder.getSubqueryAliasBuilder
+ .setInput(plan.getRoot)
+ .setAlias(alias)
+ }
+
+ /**
+ * (Scala-specific) Returns a new Dataset with an alias set.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def as(alias: Symbol): Dataset[T] = as(alias.name)
+
+ /**
+ * Returns a new Dataset with an alias set. Same as `as`.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def alias(alias: String): Dataset[T] = as(alias)
+
+ /**
+ * (Scala-specific) Returns a new Dataset with an alias set. Same as `as`.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def alias(alias: Symbol): Dataset[T] = as(alias)
+
+ /**
+ * Selects a set of column-based expressions.
+ * {{{
+ * ds.select($"colA", $"colB" + 1)
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def select(cols: Column*): DataFrame = sparkSession.newDataFrame { builder =>
+ builder.getProjectBuilder
+ .setInput(plan.getRoot)
+ .addAllExpressions(cols.map(_.expr).asJava)
+ }
+
+ /**
+ * Selects a set of columns. This is a variant of `select` that can only select existing columns
+ * using column names (i.e. cannot construct expressions).
+ *
+ * {{{
+ * // The following two are equivalent:
+ * ds.select("colA", "colB")
+ * ds.select($"colA", $"colB")
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)): _*)
+
+ /**
+ * Selects a set of SQL expressions. This is a variant of `select` that accepts SQL expressions.
+ *
+ * {{{
+ * // The following are equivalent:
+ * ds.selectExpr("colA", "colB as newName", "abs(colC)")
+ * ds.select(expr("colA"), expr("colB as newName"), expr("abs(colC)"))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def selectExpr(exprs: String*): DataFrame = {
+ select(exprs.map(functions.expr): _*)
+ }
+
+ /**
+ * Returns a new Dataset by computing the given [[Column]] expression for each element.
+ *
+ * {{{
+ * val ds = Seq(1, 2, 3).toDS()
+ * val newDS = ds.select(expr("value + 1").as[Int])
+ * }}}
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
+ val encoder = c1.encoder
+ val expr = if (encoder.schema == encoder.dataType) {
+ functions.inline(functions.array(c1)).expr
+ } else {
+ c1.expr
+ }
+ sparkSession.newDataset(encoder) { builder =>
+ builder.getProjectBuilder
+ .setInput(plan.getRoot)
+ .addExpressions(expr)
+ }
+ }
+
+ /**
+ * Filters rows using the given condition.
+ * {{{
+ * // The following are equivalent:
+ * peopleDs.filter($"age" > 15)
+ * peopleDs.where($"age" > 15)
+ * }}}
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def filter(condition: Column): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+ builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+ }
+
+ /**
+ * Filters rows using the given SQL expression.
+ * {{{
+ * peopleDs.filter("age > 15")
+ * }}}
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def filter(conditionExpr: String): Dataset[T] = filter(functions.expr(conditionExpr))
+
+ /**
+ * Filters rows using the given condition. This is an alias for `filter`.
+ * {{{
+ * // The following are equivalent:
+ * peopleDs.filter($"age" > 15)
+ * peopleDs.where($"age" > 15)
+ * }}}
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def where(condition: Column): Dataset[T] = filter(condition)
+
+ /**
+ * Filters rows using the given SQL expression.
+ * {{{
+ * peopleDs.where("age > 15")
+ * }}}
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def where(conditionExpr: String): Dataset[T] = filter(conditionExpr)
+
+ private def buildUnpivot(
+ ids: Array[Column],
+ valuesOption: Option[Array[Column]],
+ variableColumnName: String,
+ valueColumnName: String): DataFrame = sparkSession.newDataFrame { builder =>
+ val unpivot = builder.getUnpivotBuilder
+ .setInput(plan.getRoot)
+ .addAllIds(ids.toSeq.map(_.expr).asJava)
+ .setVariableColumnName(variableColumnName)
+ .setValueColumnName(valueColumnName)
+ valuesOption.foreach { values =>
+ unpivot.getValuesBuilder
+ .addAllValues(values.toSeq.map(_.expr).asJava)
+ }
+ }
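+
+ // Editor's note: the variable column receives the former column names and the value column the
+ // former cell values; see the public `unpivot` overloads below for a worked example.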
+
+ /**
+ * Groups the Dataset using the specified columns, so we can run aggregation on them. See
+ * [[RelationalGroupedDataset]] for all the available aggregate functions.
+ *
+ * {{{
+ * // Compute the average for all numeric columns grouped by department.
+ * ds.groupBy($"department").avg()
+ *
+ * // Compute the max age and average salary, grouped by department and gender.
+ * ds.groupBy($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def groupBy(cols: Column*): RelationalGroupedDataset = {
+ new RelationalGroupedDataset(
+ toDF(),
+ cols.map(_.expr),
+ proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+ }
+
+ /**
+ * Groups the Dataset using the specified columns, so that we can run aggregation on them. See
+ * [[RelationalGroupedDataset]] for all the available aggregate functions.
+ *
+ * This is a variant of groupBy that can only group by existing columns using column names (i.e.
+ * cannot construct expressions).
+ *
+ * {{{
+ * // Compute the average for all numeric columns grouped by department.
+ * ds.groupBy("department").avg()
+ *
+ * // Compute the max age and average salary, grouped by department and gender.
+ * ds.groupBy($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def groupBy(col1: String, cols: String*): RelationalGroupedDataset = {
+ val colNames: Seq[String] = col1 +: cols
+ new RelationalGroupedDataset(
+ toDF(),
+ colNames.map(colName => Column(colName).expr),
+ proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+ }
+
+ /**
+ * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we
+ * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+ * functions.
+ *
+ * {{{
+ * // Compute the average for all numeric columns rolled up by department and group.
+ * ds.rollup($"department", $"group").avg()
+ *
+ * // Compute the max age and average salary, rolled up by department and gender.
+ * ds.rollup($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def rollup(cols: Column*): RelationalGroupedDataset = {
+ new RelationalGroupedDataset(
+ toDF(),
+ cols.map(_.expr),
+ proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+ }
+
+ /**
+ * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we
+ * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+ * functions.
+ *
+ * This is a variant of rollup that can only group by existing columns using column names (i.e.
+ * cannot construct expressions).
+ *
+ * {{{
+ * // Compute the average for all numeric columns rolled up by department and group.
+ * ds.rollup("department", "group").avg()
+ *
+ * // Compute the max age and average salary, rolled up by department and gender.
+ * ds.rollup($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def rollup(col1: String, cols: String*): RelationalGroupedDataset = {
+ val colNames: Seq[String] = col1 +: cols
+ new RelationalGroupedDataset(
+ toDF(),
+ colNames.map(colName => Column(colName).expr),
+ proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+ }
+
+ /**
+ * Create a multi-dimensional cube for the current Dataset using the specified columns, so we
+ * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+ * functions.
+ *
+ * {{{
+ * // Compute the average for all numeric columns cubed by department and group.
+ * ds.cube($"department", $"group").avg()
+ *
+ * // Compute the max age and average salary, cubed by department and gender.
+ * ds.cube($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def cube(cols: Column*): RelationalGroupedDataset = {
+ new RelationalGroupedDataset(
+ toDF(),
+ cols.map(_.expr),
+ proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+ }
+
+ /**
+ * Create a multi-dimensional cube for the current Dataset using the specified columns, so we
+ * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+ * functions.
+ *
+ * This is a variant of cube that can only group by existing columns using column names (i.e.
+ * cannot construct expressions).
+ *
+ * {{{
+ * // Compute the average for all numeric columns cubed by department and group.
+ * ds.cube("department", "group").avg()
+ *
+ * // Compute the max age and average salary, cubed by department and gender.
+ * ds.cube($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def cube(col1: String, cols: String*): RelationalGroupedDataset = {
+ val colNames: Seq[String] = col1 +: cols
+ new RelationalGroupedDataset(
+ toDF(),
+ colNames.map(colName => Column(colName).expr),
+ proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+ }
+
+ /**
+ * (Scala-specific) Aggregates on the entire Dataset without groups.
+ * {{{
+ * // ds.agg(...) is a shorthand for ds.groupBy().agg(...)
+ * ds.agg("age" -> "max", "salary" -> "avg")
+ * ds.groupBy().agg("age" -> "max", "salary" -> "avg")
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
+ groupBy().agg(aggExpr, aggExprs: _*)
+ }
+
+ /**
+ * (Scala-specific) Aggregates on the entire Dataset without groups.
+ * {{{
+ * // ds.agg(...) is a shorthand for ds.groupBy().agg(...)
+ * ds.agg(Map("age" -> "max", "salary" -> "avg"))
+ * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
+
+ /**
+ * (Java-specific) Aggregates on the entire Dataset without groups.
+ * {{{
+ * // ds.agg(...) is a shorthand for ds.groupBy().agg(...)
+ * ds.agg(Map("age" -> "max", "salary" -> "avg"))
+ * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs)
+
+ /**
+ * Aggregates on the entire Dataset without groups.
+ * {{{
+ * // ds.agg(...) is a shorthand for ds.groupBy().agg(...)
+ * ds.agg(max($"age"), avg($"salary"))
+ * ds.groupBy().agg(max($"age"), avg($"salary"))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs: _*)
+
+ /**
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+ * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * which cannot be reversed.
+ *
+ * This function is useful to massage a DataFrame into a format where some columns are
+ * identifier columns ("ids"), while all other columns ("values") are "unpivoted" to the rows,
+ * leaving just two non-id columns, named as given by `variableColumnName` and
+ * `valueColumnName`.
+ *
+ * {{{
+ * val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long")
+ * df.show()
+ * // output:
+ * // +---+---+----+
+ * // | id|int|long|
+ * // +---+---+----+
+ * // | 1| 11| 12|
+ * // | 2| 21| 22|
+ * // +---+---+----+
+ *
+ * df.unpivot(Array($"id"), Array($"int", $"long"), "variable", "value").show()
+ * // output:
+ * // +---+--------+-----+
+ * // | id|variable|value|
+ * // +---+--------+-----+
+ * // | 1| int| 11|
+ * // | 1| long| 12|
+ * // | 2| int| 21|
+ * // | 2| long| 22|
+ * // +---+--------+-----+
+ * // schema:
+ * //root
+ * // |-- id: integer (nullable = false)
+ * // |-- variable: string (nullable = false)
+ * // |-- value: long (nullable = true)
+ * }}}
+ *
+ * When no "id" columns are given, the unpivoted DataFrame consists of only the "variable" and
+ * "value" columns.
+ *
+ * All "value" columns must share a least common data type. Unless they are the same data type,
+ * all "value" columns are cast to the nearest common data type. For instance, types
+ * `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType` do
+ * not have a common data type and `unpivot` fails with an `AnalysisException`.
+ *
+ * @param ids
+ * Id columns
+ * @param values
+ * Value columns to unpivot
+ * @param variableColumnName
+ * Name of the variable column
+ * @param valueColumnName
+ * Name of the value column
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def unpivot(
+ ids: Array[Column],
+ values: Array[Column],
+ variableColumnName: String,
+ valueColumnName: String): DataFrame = {
+ buildUnpivot(ids, Option(values), variableColumnName, valueColumnName)
+ }
+
+ /**
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+ * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * which cannot be reversed.
+ *
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+ *
+ * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)` where `values`
+ * is set to all non-id columns that exist in the DataFrame.
+ *
+ * @param ids
+ * Id columns
+ * @param variableColumnName
+ * Name of the variable column
+ * @param valueColumnName
+ * Name of the value column
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def unpivot(
+ ids: Array[Column],
+ variableColumnName: String,
+ valueColumnName: String): DataFrame = {
+ buildUnpivot(ids, None, variableColumnName, valueColumnName)
+ }
+
+ /**
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+ * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * which cannot be reversed. This is an alias for `unpivot`.
+ *
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+ *
+ * @param ids
+ * Id columns
+ * @param values
+ * Value columns to unpivot
+ * @param variableColumnName
+ * Name of the variable column
+ * @param valueColumnName
+ * Name of the value column
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def melt(
+ ids: Array[Column],
+ values: Array[Column],
+ variableColumnName: String,
+ valueColumnName: String): DataFrame =
+ unpivot(ids, values, variableColumnName, valueColumnName)
+
+ /**
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+ * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * which cannot be reversed. This is an alias for `unpivot`.
+ *
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+ *
+ * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)` where `values`
+ * is set to all non-id columns that exist in the DataFrame.
+ *
+ * @param ids
+ * Id columns
+ * @param variableColumnName
+ * Name of the variable column
+ * @param valueColumnName
+ * Name of the value column
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def melt(ids: Array[Column], variableColumnName: String, valueColumnName: String): DataFrame =
+ unpivot(ids, variableColumnName, valueColumnName)
+
+ /**
+ * Returns a new Dataset by taking the first `n` rows. The difference between this function and
+ * `head` is that `head` is an action and returns an array (by triggering query execution) while
+ * `limit` returns a new Dataset.
+ *
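+ * A minimal illustration of the difference (a sketch, assuming a Dataset `ds`):
+ * {{{
+ *   // Lazily returns a new Dataset with at most 5 rows.
+ *   val firstFive = ds.limit(5)
+ *
+ *   // head(5) triggers execution and returns an Array instead.
+ *   val firstFiveRows = ds.head(5)
+ * }}}
+ *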
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def limit(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+ builder.getLimitBuilder
+ .setInput(plan.getRoot)
+ .setLimit(n)
+ }
+
+ /**
+ * Returns a new Dataset by skipping the first `n` rows.
+ *
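+ * For example (a sketch), `offset` can be combined with `limit` for simple pagination:
+ * {{{
+ *   // Skips the first 10 rows and keeps the next 10.
+ *   ds.offset(10).limit(10)
+ * }}}
+ *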
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def offset(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+ builder.getOffsetBuilder
+ .setInput(plan.getRoot)
+ .setOffset(n)
+ }
+
+ private def buildSetOp(right: Dataset[T], setOpType: proto.SetOperation.SetOpType)(
+ f: proto.SetOperation.Builder => Unit): Dataset[T] = {
+ sparkSession.newDataset(encoder) { builder =>
+ f(
+ builder.getSetOpBuilder
+ .setSetOpType(setOpType)
+ .setLeftInput(plan.getRoot)
+ .setRightInput(right.plan.getRoot))
+ }
+ }
+
+ /**
+ * Returns a new Dataset containing union of rows in this Dataset and another Dataset.
+ *
+ * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
+ * deduplication of elements), use this function followed by a [[distinct]].
+ *
+ * Also as standard in SQL, this function resolves columns by position (not by name):
+ *
+ * {{{
+ * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
+ * val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0")
+ * df1.union(df2).show
+ *
+ * // output:
+ * // +----+----+----+
+ * // |col0|col1|col2|
+ * // +----+----+----+
+ * // | 1| 2| 3|
+ * // | 4| 5| 6|
+ * // +----+----+----+
+ * }}}
+ *
+ * Notice that the column positions in the schema aren't necessarily matched with the fields in
+ * the strongly typed objects in a Dataset. This function resolves columns by their positions in
+ * the schema, not the fields in the strongly typed objects. Use [[unionByName]] to resolve
+ * columns by field name in the typed objects.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def union(other: Dataset[T]): Dataset[T] = {
+ buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_UNION) { builder =>
+ builder.setIsAll(true)
+ }
+ }
+
+ /**
+ * Returns a new Dataset containing union of rows in this Dataset and another Dataset. This is
+ * an alias for `union`.
+ *
+ * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
+ * deduplication of elements), use this function followed by a [[distinct]].
+ *
+ * Also as standard in SQL, this function resolves columns by position (not by name).
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def unionAll(other: Dataset[T]): Dataset[T] = union(other)
+
+ /**
+ * Returns a new Dataset containing union of rows in this Dataset and another Dataset.
+ *
+ * This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set
+ * union (that does deduplication of elements), use this function followed by a [[distinct]].
+ *
+ * The difference between this function and [[union]] is that this function resolves columns by
+ * name (not by position):
+ *
+ * {{{
+ * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
+ * val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0")
+ * df1.unionByName(df2).show
+ *
+ * // output:
+ * // +----+----+----+
+ * // |col0|col1|col2|
+ * // +----+----+----+
+ * // | 1| 2| 3|
+ * // | 6| 4| 5|
+ * // +----+----+----+
+ * }}}
+ *
+ * Note that this supports nested columns in struct and array types. Nested columns in map types
+ * are not currently supported.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def unionByName(other: Dataset[T]): Dataset[T] = unionByName(other, allowMissingColumns = false)
+
+ /**
+ * Returns a new Dataset containing union of rows in this Dataset and another Dataset.
+ *
+ * The difference between this function and [[union]] is that this function resolves columns by
+ * name (not by position).
+ *
+ * When the parameter `allowMissingColumns` is `true`, the set of column names in this and other
+ * `Dataset` can differ; missing columns will be filled with null. Further, the missing columns
+ * of this `Dataset` will be added at the end in the schema of the union result:
+ *
+ * {{{
+ * val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
+ * val df2 = Seq((4, 5, 6)).toDF("col1", "col0", "col3")
+ * df1.unionByName(df2, true).show
+ *
+ * // output: "col3" is missing at left df1 and added at the end of schema.
+ * // +----+----+----+----+
+ * // |col0|col1|col2|col3|
+ * // +----+----+----+----+
+ * // | 1| 2| 3|null|
+ * // | 5| 4|null| 6|
+ * // +----+----+----+----+
+ *
+ * df2.unionByName(df1, true).show
+ *
+ * // output: "col2" is missing at left df2 and added at the end of schema.
+ * // +----+----+----+----+
+ * // |col1|col0|col3|col2|
+ * // +----+----+----+----+
+ * // | 4| 5| 6|null|
+ * // | 2| 1|null| 3|
+ * // +----+----+----+----+
+ * }}}
+ *
+ * Note that this supports nested columns in struct and array types. With `allowMissingColumns`,
+ * missing nested columns of struct columns with the same name will also be filled with null
+ * values and added to the end of struct. Nested columns in map types are not currently
+ * supported.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = {
+ buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_UNION) { builder =>
+ builder.setByName(true).setIsAll(true).setAllowMissingColumns(allowMissingColumns)
+ }
+ }
+
+ /**
+ * Returns a new Dataset containing rows only in both this Dataset and another Dataset. This is
+ * equivalent to `INTERSECT` in SQL.
+ *
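+ * A small sketch (assuming `spark.implicits._` is imported):
+ * {{{
+ *   val ds1 = Seq(1, 2, 2, 3).toDS()
+ *   val ds2 = Seq(2, 3, 4).toDS()
+ *   ds1.intersect(ds2)   // contains 2 and 3, duplicates removed
+ * }}}
+ *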
+ * @note
+ * Equality checking is performed directly on the encoded representation of the data and thus
+ * is not affected by a custom `equals` function defined on `T`.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def intersect(other: Dataset[T]): Dataset[T] = {
+ buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_INTERSECT) { builder =>
+ builder.setIsAll(false)
+ }
+ }
+
+ /**
+ * Returns a new Dataset containing rows only in both this Dataset and another Dataset while
+ * preserving the duplicates. This is equivalent to `INTERSECT ALL` in SQL.
+ *
+ * @note
+ * Equality checking is performed directly on the encoded representation of the data and thus
+ * is not affected by a custom `equals` function defined on `T`. Also as standard in SQL, this
+ * function resolves columns by position (not by name).
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def intersectAll(other: Dataset[T]): Dataset[T] = {
+ buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_INTERSECT) { builder =>
+ builder.setIsAll(true)
+ }
+ }
+
+ /**
+ * Returns a new Dataset containing rows in this Dataset but not in another Dataset. This is
+ * equivalent to `EXCEPT DISTINCT` in SQL.
+ *
+ * @note
+ * Equality checking is performed directly on the encoded representation of the data and thus
+ * is not affected by a custom `equals` function defined on `T`.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def except(other: Dataset[T]): Dataset[T] = {
+ buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_EXCEPT) { builder =>
+ builder.setIsAll(false)
+ }
+ }
+
+ /**
+ * Returns a new Dataset containing rows in this Dataset but not in another Dataset while
+ * preserving the duplicates. This is equivalent to `EXCEPT ALL` in SQL.
+ *
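+ * A small sketch (assuming `spark.implicits._` is imported):
+ * {{{
+ *   val ds1 = Seq(1, 2, 2, 3).toDS()
+ *   val ds2 = Seq(2, 4).toDS()
+ *   ds1.exceptAll(ds2)   // contains 1, 2 and 3; one of the duplicate 2s is kept
+ * }}}
+ *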
+ * @note
+ * Equality checking is performed directly on the encoded representation of the data and thus
+ * is not affected by a custom `equals` function defined on `T`. Also as standard in SQL, this
+ * function resolves columns by position (not by name).
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def exceptAll(other: Dataset[T]): Dataset[T] = {
+ buildSetOp(other, proto.SetOperation.SetOpType.SET_OP_TYPE_EXCEPT) { builder =>
+ builder.setIsAll(true)
+ }
+ }
+
+ /**
+ * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a
+ * user-supplied seed.
+ *
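+ * For instance (a rough sketch; the exact number of returned rows varies):
+ * {{{
+ *   // Roughly 10% of the rows, reproducible for a fixed seed.
+ *   val sampled = ds.sample(0.1, 42L)
+ * }}}
+ *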
+ * @param fraction
+ * Fraction of rows to generate, range [0.0, 1.0].
+ * @param seed
+ * Seed for sampling.
+ *
+ * @note
+ * This is NOT guaranteed to provide exactly the fraction of the count of the given
+ * [[Dataset]].
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def sample(fraction: Double, seed: Long): Dataset[T] = {
+ sample(withReplacement = false, fraction = fraction, seed = seed)
+ }
+
+ /**
+ * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a
+ * random seed.
+ *
+ * @param fraction
+ * Fraction of rows to generate, range [0.0, 1.0].
+ *
+ * @note
+ * This is NOT guaranteed to provide exactly the fraction of the count of the given
+ * [[Dataset]].
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def sample(fraction: Double): Dataset[T] = {
+ sample(withReplacement = false, fraction = fraction)
+ }
+
+ /**
+ * Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed.
+ *
+ * @param withReplacement
+ * Sample with replacement or not.
+ * @param fraction
+ * Fraction of rows to generate, range [0.0, 1.0].
+ * @param seed
+ * Seed for sampling.
+ *
+ * @note
+ * This is NOT guaranteed to provide exactly the fraction of the count of the given
+ * [[Dataset]].
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
+ sparkSession.newDataset(encoder) { builder =>
+ builder.getSampleBuilder
+ .setInput(plan.getRoot)
+ .setWithReplacement(withReplacement)
+ .setLowerBound(0.0d)
+ .setUpperBound(fraction)
+ .setSeed(seed)
+ }
+ }
+
+ /**
+ * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed.
+ *
+ * @param withReplacement
+ * Sample with replacement or not.
+ * @param fraction
+ * Fraction of rows to generate, range [0.0, 1.0].
+ *
+ * @note
+ * This is NOT guaranteed to provide exactly the fraction of the total count of the given
+ * [[Dataset]].
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = {
+ sample(withReplacement, fraction, Utils.random.nextLong)
+ }
+
+ /**
+ * Randomly splits this Dataset with the provided weights.
+ *
+ * @param weights
+ * weights for splits, will be normalized if they don't sum to 1.
+ * @param seed
+ * Seed for sampling.
+ *
+ * For Java API, use [[randomSplitAsList]].
+ *
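+ * A typical use (a sketch; the weights 0.8 and 0.2 already sum to 1):
+ * {{{
+ *   val Array(train, test) = ds.randomSplit(Array(0.8, 0.2), seed = 42L)
+ * }}}
+ *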
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = {
+ require(
+ weights.forall(_ >= 0),
+ s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}")
+ require(
+ weights.sum > 0,
+ s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}")
+
+ // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
+ // constituent partitions each time a split is materialized, which could result in
+ // overlapping splits. To prevent this, we explicitly sort each input partition to make the
+ // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
+ // from the sort order.
+ // TODO we need to have a proper way of stabilizing the input data. The current approach does
+ // not work well with Spark Connect's extremely lazy nature. When the schema is modified
+ // between construction and execution the query might fail or produce wrong results. Another
+ // problem can come from data that arrives between the execution of the returned datasets.
+ val sortOrder = schema.collect {
+ case f if RowOrdering.isOrderable(f.dataType) => col(f.name).asc
+ }
+ val sortedInput = sortWithinPartitions(sortOrder: _*).plan.getRoot
+ val sum = weights.sum
+ val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+ normalizedCumWeights
+ .sliding(2)
+ .map { case Array(low, high) =>
+ sparkSession.newDataset(encoder) { builder =>
+ builder.getSampleBuilder
+ .setInput(sortedInput)
+ .setWithReplacement(false)
+ .setLowerBound(low)
+ .setUpperBound(high)
+ .setSeed(seed)
+ }
+ }
+ .toArray
+ }
+
+ /**
+ * Returns a Java list that contains randomly split Dataset with the provided weights.
+ *
+ * @param weights
+ * weights for splits, will be normalized if they don't sum to 1.
+ * @param seed
+ * Seed for sampling.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = {
+ val values = randomSplit(weights, seed)
+ java.util.Arrays.asList(values: _*)
+ }
+
+ /**
+ * Randomly splits this Dataset with the provided weights.
+ *
+ * @param weights
+ * weights for splits, will be normalized if they don't sum to 1.
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def randomSplit(weights: Array[Double]): Array[Dataset[T]] = {
+ randomSplit(weights, Utils.random.nextLong)
+ }
+
+ private def withColumns(names: Seq[String], values: Seq[Column]): DataFrame = {
+ val aliases = values.zip(names).map { case (value, name) =>
+ value.name(name).expr.getAlias
+ }
+ sparkSession.newDataFrame { builder =>
+ builder.getWithColumnsBuilder
+ .setInput(plan.getRoot)
+ .addAllAliases(aliases.asJava)
+ }
+ }
+
+ /**
+ * Returns a new Dataset by adding a column or replacing the existing column that has the same
+ * name.
+ *
+ * `column`'s expression must only refer to attributes supplied by this Dataset. It is an error
+ * to add a column that refers to some other Dataset.
+ *
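+ * For example (a sketch, assuming `org.apache.spark.sql.functions.lit` is in scope):
+ * {{{
+ *   val flagged = ds.withColumn("flag", lit(true))
+ *
+ *   // Prefer a single call when adding several columns at once:
+ *   val enriched = ds.withColumns(Map("flag" -> lit(true), "score" -> lit(0)))
+ * }}}
+ *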
+ * @note
+ * this method introduces a projection internally. Therefore, calling it multiple times, for
+ * instance, via loops in order to add multiple columns can generate big plans which can cause
+ * performance issues and even `StackOverflowException`. To avoid this, use `select` with the
+ * multiple columns at once.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col))
+
+ /**
+ * (Scala-specific) Returns a new Dataset by adding columns or replacing the existing columns
+ * that have the same names.
+ *
+ * `colsMap` is a map of column name to column; the columns must only refer to attributes
+ * supplied by this Dataset. It is an error to add columns that refer to some other Dataset.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def withColumns(colsMap: Map[String, Column]): DataFrame = {
+ val (colNames, newCols) = colsMap.toSeq.unzip
+ withColumns(colNames, newCols)
+ }
+
+ /**
+ * (Java-specific) Returns a new Dataset by adding columns or replacing the existing columns
+ * that have the same names.
+ *
+ * `colsMap` is a map of column name to column; the columns must only refer to attributes
+ * supplied by this Dataset. It is an error to add columns that refer to some other Dataset.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def withColumns(colsMap: java.util.Map[String, Column]): DataFrame = withColumns(
+ colsMap.asScala.toMap)
+
+ /**
+ * Returns a new Dataset with a column renamed. This is a no-op if schema doesn't contain
+ * existingName.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def withColumnRenamed(existingName: String, newName: String): DataFrame = {
+ withColumnsRenamed(Collections.singletonMap(existingName, newName))
+ }
+
+ /**
+ * (Scala-specific) Returns a new Dataset with columns renamed. This is a no-op if the schema
+ * doesn't contain the existing column names.
+ *
+ * `colsMap` is a map of existing column name and new column name.
+ *
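+ * For example (a sketch with hypothetical column names):
+ * {{{
+ *   ds.withColumnsRenamed(Map("dob" -> "date_of_birth", "sal" -> "salary"))
+ * }}}
+ *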
+ * @throws AnalysisException
+ * if there are duplicate names in resulting projection
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @throws[AnalysisException]
+ def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = {
+ withColumnsRenamed(colsMap.asJava)
+ }
+
+ /**
+ * (Java-specific) Returns a new Dataset with columns renamed. This is a no-op if the schema
+ * doesn't contain the existing column names.
+ *
+ * `colsMap` is a map of existing column name and new column name.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def withColumnsRenamed(colsMap: java.util.Map[String, String]): DataFrame = {
+ sparkSession.newDataFrame { builder =>
+ builder.getWithColumnsRenamedBuilder
+ .setInput(plan.getRoot)
+ .putAllRenameColumnsMap(colsMap)
+ }
+ }
+
+ /**
+ * Returns a new Dataset by updating an existing column with metadata.
+ *
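+ * For example (a sketch, assuming `org.apache.spark.sql.types.MetadataBuilder` is imported and
+ * a hypothetical "age" column exists):
+ * {{{
+ *   val md = new MetadataBuilder().putString("comment", "age in years").build()
+ *   ds.withMetadata("age", md)
+ * }}}
+ *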
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def withMetadata(columnName: String, metadata: Metadata): DataFrame = {
+ val newAlias = proto.Expression.Alias
+ .newBuilder()
+ .setExpr(col(columnName).expr)
+ .addName(columnName)
+ .setMetadata(metadata.json)
+ sparkSession.newDataFrame { builder =>
+ builder.getWithColumnsBuilder
+ .setInput(plan.getRoot)
+ .addAliases(newAlias)
+ }
+ }
+
+ /**
+ * Registers this Dataset as a temporary table using the given name. The lifetime of this
+ * temporary table is tied to the [[SparkSession]] that was used to create this Dataset.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ @deprecated("Use createOrReplaceTempView(viewName) instead.", "3.4.0")
+ def registerTempTable(tableName: String): Unit = {
+ createOrReplaceTempView(tableName)
+ }
+
+ /**
+ * Creates a local temporary view using the given name. The lifetime of this temporary view is
+ * tied to the [[SparkSession]] that was used to create this Dataset.
+ *
+ * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that
+ * created it, i.e. it will be automatically dropped when the session terminates. It's not tied
+ * to any databases, i.e. we can't use `db1.view1` to reference a local temporary view.
+ *
+ * @throws AnalysisException
+ * if the view name is invalid or already exists
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ @throws[AnalysisException]
+ def createTempView(viewName: String): Unit = {
+ buildAndExecuteTempView(viewName, replace = false, global = false)
+ }
+
+ /**
+ * Creates a local temporary view using the given name. The lifetime of this temporary view is
+ * tied to the [[SparkSession]] that was used to create this Dataset.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def createOrReplaceTempView(viewName: String): Unit = {
+ buildAndExecuteTempView(viewName, replace = true, global = false)
+ }
+
+ /**
+ * Creates a global temporary view using the given name. The lifetime of this temporary view is
+ * tied to this Spark application.
+ *
+ * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark
+ * application, i.e. it will be automatically dropped when the application terminates. It's
+ * tied to a system preserved database `global_temp`, and we must use the qualified name to
+ * refer to a global temp view, e.g. `SELECT * FROM global_temp.view1`.
+ *
+ * @throws AnalysisException
+ * if the view name is invalid or already exists
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ @throws[AnalysisException]
+ def createGlobalTempView(viewName: String): Unit = {
+ buildAndExecuteTempView(viewName, replace = false, global = true)
+ }
+
+ /**
+ * Creates or replaces a global temporary view using the given name. The lifetime of this
+ * temporary view is tied to this Spark application.
+ *
+ * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark
+ * application, i.e. it will be automatically dropped when the application terminates. It's
+ * tied to a system preserved database `global_temp`, and we must use the qualified name to
+ * refer to a global temp view, e.g. `SELECT * FROM global_temp.view1`.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def createOrReplaceGlobalTempView(viewName: String): Unit = {
+ buildAndExecuteTempView(viewName, replace = true, global = true)
+ }
+
+ private def buildAndExecuteTempView(
+ viewName: String,
+ replace: Boolean,
+ global: Boolean): Unit = {
+ val command = sparkSession.newCommand { builder =>
+ builder.getCreateDataframeViewBuilder
+ .setInput(plan.getRoot)
+ .setName(viewName)
+ .setIsGlobal(global)
+ .setReplace(replace)
+ }
+ sparkSession.execute(command)
+ }
+
+ /**
+ * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain column
+ * name.
+ *
+ * This method can only be used to drop top level columns. The colName string is treated
+ * literally without further interpretation.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def drop(colName: String): DataFrame = {
+ drop(Seq(colName): _*)
+ }
+
+ /**
+ * Returns a new Dataset with columns dropped. This is a no-op if schema doesn't contain column
+ * name(s).
+ *
+ * This method can only be used to drop top level columns. The colName string is treated
+ * literally without further interpretation.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def drop(colNames: String*): DataFrame = buildDropByNames(colNames)
+
+ /**
+ * Returns a new Dataset with column dropped.
+ *
+ * This method can only be used to drop a top level column. This version of drop accepts a
+ * [[Column]] rather than a name. This is a no-op if the Dataset doesn't have a column with an
+ * equivalent expression.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def drop(col: Column): DataFrame = {
+ buildDrop(col :: Nil)
+ }
+
+ /**
+ * Returns a new Dataset with columns dropped.
+ *
+ * This method can only be used to drop top level columns. This is a no-op if the Dataset
+ * doesn't have any columns with an equivalent expression.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def drop(col: Column, cols: Column*): DataFrame = buildDrop(col +: cols)
+
+ private def buildDrop(cols: Seq[Column]): DataFrame = sparkSession.newDataFrame { builder =>
+ builder.getDropBuilder
+ .setInput(plan.getRoot)
+ .addAllColumns(cols.map(_.expr).asJava)
+ }
+
+ private def buildDropByNames(cols: Seq[String]): DataFrame = sparkSession.newDataFrame {
+ builder =>
+ builder.getDropBuilder
+ .setInput(plan.getRoot)
+ .addAllColumnNames(cols.asJava)
+ }
+
+ /**
+ * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias
+ * for `distinct`.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def dropDuplicates(): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+ builder.getDeduplicateBuilder
+ .setInput(plan.getRoot)
+ .setAllColumnsAsKeys(true)
+ }
+
+ /**
+ * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only the
+ * subset of columns.
+ *
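+ * For example (a sketch with hypothetical column names):
+ * {{{
+ *   // Keeps a single row for every (firstName, lastName) combination.
+ *   ds.dropDuplicates(Seq("firstName", "lastName"))
+ * }}}
+ *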
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def dropDuplicates(colNames: Seq[String]): Dataset[T] = sparkSession.newDataset(encoder) {
+ builder =>
+ builder.getDeduplicateBuilder
+ .setInput(plan.getRoot)
+ .addAllColumnNames(colNames.asJava)
+ }
+
+ /**
+ * Returns a new Dataset with duplicate rows removed, considering only the subset of columns.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq)
+
+ /**
+ * Returns a new [[Dataset]] with duplicate rows removed, considering only the subset of
+ * columns.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def dropDuplicates(col1: String, cols: String*): Dataset[T] = {
+ val colNames: Seq[String] = col1 +: cols
+ dropDuplicates(colNames)
+ }
+
+ /**
+ * Computes basic statistics for numeric and string columns, including count, mean, stddev, min,
+ * and max. If no columns are given, this function computes statistics for all numerical or
+ * string columns.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting Dataset. If you want to
+ * programmatically compute summary statistics, use the `agg` function instead.
+ *
+ * {{{
+ * ds.describe("age", "height").show()
+ *
+ * // output:
+ * // summary age height
+ * // count 10.0 10.0
+ * // mean 53.3 178.05
+ * // stddev 11.6 15.7
+ * // min 18.0 163.0
+ * // max 92.0 192.0
+ * }}}
+ *
+ * Use [[summary]] for expanded statistics and control over which statistics to compute.
+ *
+ * @param cols
+ * Columns to compute statistics on.
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def describe(cols: String*): DataFrame = sparkSession.newDataFrame { builder =>
+ builder.getDescribeBuilder
+ .setInput(plan.getRoot)
+ .addAllCols(cols.asJava)
+ }
+
+ /**
+ * Computes specified statistics for numeric and string columns. Available statistics are:
+ * <ul>
+ * <li>count</li>
+ * <li>mean</li>
+ * <li>stddev</li>
+ * <li>min</li>
+ * <li>max</li>
+ * <li>arbitrary approximate percentiles specified as a percentage (e.g. 75%)</li>
+ * <li>count_distinct</li>
+ * <li>approx_count_distinct</li>
+ * </ul>
+ *
+ * If no statistics are given, this function computes count, mean, stddev, min, approximate
+ * quartiles (percentiles at 25%, 50%, and 75%), and max.
+ *
+ * This function is meant for exploratory data analysis, as we make no guarantee about the
+ * backward compatibility of the schema of the resulting Dataset. If you want to
+ * programmatically compute summary statistics, use the `agg` function instead.
+ *
+ * {{{
+ * ds.summary().show()
+ *
+ * // output:
+ * // summary age height
+ * // count 10.0 10.0
+ * // mean 53.3 178.05
+ * // stddev 11.6 15.7
+ * // min 18.0 163.0
+ * // 25% 24.0 176.0
+ * // 50% 24.0 176.0
+ * // 75% 32.0 180.0
+ * // max 92.0 192.0
+ * }}}
+ *
+ * {{{
+ * ds.summary("count", "min", "25%", "75%", "max").show()
+ *
+ * // output:
+ * // summary age height
+ * // count 10.0 10.0
+ * // min 18.0 163.0
+ * // 25% 24.0 176.0
+ * // 75% 32.0 180.0
+ * // max 92.0 192.0
+ * }}}
+ *
+ * To do a summary for specific columns first select them:
+ *
+ * {{{
+ * ds.select("age", "height").summary().show()
+ * }}}
+ *
+ * Specify statistics to output custom summaries:
+ *
+ * {{{
+ * ds.summary("count", "count_distinct").show()
+ * }}}
+ *
+ * The distinct count isn't included by default.
+ *
+ * You can also run approximate distinct counts which are faster:
+ *
+ * {{{
+ * ds.summary("count", "approx_count_distinct").show()
+ * }}}
+ *
+ * See also [[describe]] for basic statistics.
+ *
+ * @param statistics
+ * Statistics from above list to be computed.
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def summary(statistics: String*): DataFrame = sparkSession.newDataFrame { builder =>
+ builder.getSummaryBuilder
+ .setInput(plan.getRoot)
+ .addAllStatistics(statistics.asJava)
+ }
+
+ /**
+ * Returns the first `n` rows.
+ *
+ * @note
+ * this method should only be used if the resulting array is expected to be small, as all the
+ * data is loaded into the driver's memory.
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ def head(n: Int): Array[T] = limit(n).collect()
+
+ /**
+ * Returns the first row.
+ * @group action
+ * @since 3.4.0
+ */
+ def head(): T = head(1).head
+
+ /**
+ * Returns the first row. Alias for head().
+ * @group action
+ * @since 3.4.0
+ */
+ def first(): T = head()
+
+ /**
+ * Concise syntax for chaining custom transformations.
+ * {{{
+ * def featurize(ds: Dataset[T]): Dataset[U] = ...
+ *
+ * ds
+ * .transform(featurize)
+ * .transform(...)
+ * }}}
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
+
+ /**
+ * Returns the first `n` rows in the Dataset.
+ *
+ * Running take requires moving data into the application's driver process, and doing so with a
+ * very large `n` can crash the driver process with OutOfMemoryError.
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ def take(n: Int): Array[T] = head(n)
+
+ /**
+ * Returns the last `n` rows in the Dataset.
+ *
+ * Running tail requires moving data into the application's driver process, and doing so with a
+ * very large `n` can crash the driver process with OutOfMemoryError.
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ def tail(n: Int): Array[T] = {
+ val lastN = sparkSession.newDataset(encoder) { builder =>
+ builder.getTailBuilder
+ .setInput(plan.getRoot)
+ .setLimit(n)
+ }
+ lastN.collect()
+ }
+
+ /**
+ * Returns the first `n` rows in the Dataset as a list.
+ *
+ * Running take requires moving data into the application's driver process, and doing so with a
+ * very large `n` can crash the driver process with OutOfMemoryError.
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n): _*)
+
+ /**
+ * Returns an array that contains all rows in this Dataset.
+ *
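+ * For example (a sketch, reusing the DataFrame from the example above):
+ * {{{
+ *   // Unpivots every column except "id".
+ *   df.unpivot(Array($"id"), "variable", "value")
+ * }}}
+ *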
+ * Running collect requires moving all the data into the application's driver process, and doing
+ * so on a very large dataset can crash the driver process with OutOfMemoryError.
+ *
+ * For Java API, use [[collectAsList]].
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ def collect(): Array[T] = withResult { result =>
+ result.toArray
+ }
+
+ /**
+ * Returns a Java list that contains all rows in this Dataset.
+ *
+ * Running collect requires moving all the data into the application's driver process, and doing
+ * so on a very large dataset can crash the driver process with OutOfMemoryError.
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ def collectAsList(): java.util.List[T] = {
+ java.util.Arrays.asList(collect(): _*)
+ }
+
+ /**
+ * Returns an iterator that contains all rows in this Dataset.
+ *
+ * The returned iterator implements [[AutoCloseable]]. For memory management it is better to
+ * close it once you are done. If you don't close it, it and the underlying data will be cleaned
+ * up once the iterator is garbage collected.
+ *
+ * @group action
+ * @since 3.4.0
+ */
+ def toLocalIterator(): java.util.Iterator[T] = {
+ // TODO make this a destructive iterator.
+ collectResult().iterator
+ }
+
+ /**
+ * Returns the number of rows in the Dataset.
+ * @group action
+ * @since 3.4.0
+ */
+ def count(): Long = {
+ groupBy().count().as(PrimitiveLongEncoder).collect().head
+ }
+
+ private def buildRepartition(numPartitions: Int, shuffle: Boolean): Dataset[T] = {
+ sparkSession.newDataset(encoder) { builder =>
+ builder.getRepartitionBuilder
+ .setInput(plan.getRoot)
+ .setNumPartitions(numPartitions)
+ .setShuffle(shuffle)
+ }
+ }
+
+ private def buildRepartitionByExpression(
+ numPartitions: Option[Int],
+ partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+ val repartitionBuilder = builder.getRepartitionByExpressionBuilder
+ .setInput(plan.getRoot)
+ .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
+ numPartitions.foreach(repartitionBuilder.setNumPartitions)
+ }
+
+ /**
+ * Returns a new Dataset that has exactly `numPartitions` partitions.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def repartition(numPartitions: Int): Dataset[T] = {
+ buildRepartition(numPartitions, shuffle = true)
+ }
+
+ private def repartitionByExpression(
+ numPartitions: Option[Int],
+ partitionExprs: Seq[Column]): Dataset[T] = {
+ // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
+ // However, we don't want to complicate the semantics of this API method.
+ // Instead, let's give users a friendly error message, pointing them to the new method.
+ val sortOrders = partitionExprs.filter(_.expr.hasSortOrder)
+ if (sortOrders.nonEmpty) {
+ throw new IllegalArgumentException(
+ s"Invalid partitionExprs specified: $sortOrders\n" +
+ s"For range partitioning use repartitionByRange(...) instead.")
+ }
+ buildRepartitionByExpression(numPartitions, partitionExprs)
+ }
+
+ /**
+ * Returns a new Dataset partitioned by the given partitioning expressions into `numPartitions`.
+ * The resulting Dataset is hash partitioned.
+ *
+ * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+ repartitionByExpression(Some(numPartitions), partitionExprs)
+ }
+
+ /**
+ * Returns a new Dataset partitioned by the given partitioning expressions, using
+ * `spark.sql.shuffle.partitions` as number of partitions. The resulting Dataset is hash
+ * partitioned.
+ *
+ * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def repartition(partitionExprs: Column*): Dataset[T] = {
+ repartitionByExpression(None, partitionExprs)
+ }
+
+ private def repartitionByRange(
+ numPartitions: Option[Int],
+ partitionExprs: Seq[Column]): Dataset[T] = {
+ require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
+ val sortExprs = partitionExprs.map {
+ case e if e.expr.hasSortOrder => e
+ case e => e.asc
+ }
+ buildRepartitionByExpression(numPartitions, sortExprs)
+ }
+
+ /**
+ * Returns a new Dataset partitioned by the given partitioning expressions into `numPartitions`.
+ * The resulting Dataset is range partitioned.
+ *
+ * At least one partition-by expression must be specified. When no explicit sort order is
+ * specified, "ascending nulls first" is assumed. Note, the rows are not sorted in each
+ * partition of the resulting Dataset.
+ *
+ * Note that due to performance reasons this method uses sampling to estimate the ranges. Hence,
+ * the output may not be consistent, since sampling can return different values. The sample size
+ * can be controlled by the config `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
+ *
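+ * For example (a sketch with a hypothetical "age" column):
+ * {{{
+ *   // Range-partitions the data by "age" (ascending, nulls first) into 10 partitions.
+ *   ds.repartitionByRange(10, $"age")
+ * }}}
+ *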
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+ repartitionByRange(Some(numPartitions), partitionExprs)
+ }
+
+ /**
+ * Returns a new Dataset partitioned by the given partitioning expressions, using
+ * `spark.sql.shuffle.partitions` as number of partitions. The resulting Dataset is range
+ * partitioned.
+ *
+ * At least one partition-by expression must be specified. When no explicit sort order is
+ * specified, "ascending nulls first" is assumed. Note, the rows are not sorted in each
+ * partition of the resulting Dataset.
+ *
+ * Note that due to performance reasons this method uses sampling to estimate the ranges. Hence,
+ * the output may not be consistent, since sampling can return different values. The sample size
+ * can be controlled by the config `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def repartitionByRange(partitionExprs: Column*): Dataset[T] = {
+ repartitionByRange(None, partitionExprs)
+ }
+
+ /**
+ * Returns a new Dataset that has exactly `numPartitions` partitions, when fewer partitions
+ * are requested. If a larger number of partitions is requested, it will stay at the current
+ * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in a
+ * narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a
+ * shuffle, instead each of the 100 new partitions will claim 10 of the current partitions.
+ *
+ * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1, this may result in
+ * your computation taking place on fewer nodes than you like (e.g. one node in the case of
+ * numPartitions = 1). To avoid this, you can call repartition. This will add a shuffle step,
+ * but means the current upstream partitions will be executed in parallel (per whatever the
+ * current partitioning is).
+ *
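+ * For example (a sketch):
+ * {{{
+ *   // Shrinks to 8 partitions without a shuffle.
+ *   val compacted = ds.coalesce(8)
+ *
+ *   // For a drastic reduction, repartition(1) adds a shuffle but keeps upstream parallelism.
+ *   val single = ds.repartition(1)
+ * }}}
+ *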
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def coalesce(numPartitions: Int): Dataset[T] = {
+ buildRepartition(numPartitions, shuffle = false)
+ }
+
+ /**
+ * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias
+ * for `dropDuplicates`.
+ *
+ * Note that for a streaming [[Dataset]], this method returns distinct rows only once,
+ * regardless of the output mode, so the behavior may not be the same as `DISTINCT` in SQL
+ * against a streaming [[Dataset]].
+ *
+ * @note
+ * Equality checking is performed directly on the encoded representation of the data and thus
+ * is not affected by a custom `equals` function defined on `T`.
+ *
+ * @group typedrel
+ * @since 3.4.0
+ */
+ def distinct(): Dataset[T] = dropDuplicates()
+
+ /**
+ * Returns a best-effort snapshot of the files that compose this Dataset. This method simply
+ * asks each constituent BaseRelation for its respective files and takes the union of all
+ * results. Depending on the source relations, this may not find all input files. Duplicates are
+ * removed.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def inputFiles: Array[String] =
+ sparkSession
+ .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES)
+ .getInputFiles
+ .getFilesList
+ .asScala
+ .toArray
+
+ /**
+ * Interface for saving the content of the non-streaming Dataset out into external storage.
+ *
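+ * For example (a sketch writing to a hypothetical path, assuming the usual DataFrameWriter
+ * API):
+ * {{{
+ *   df.write.format("parquet").mode("overwrite").save("/tmp/people")
+ * }}}
+ *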
+ * @group basic
+ * @since 3.4.0
+ */
+ def write: DataFrameWriter[T] = {
+ new DataFrameWriter[T](this)
+ }
+
+ /**
+ * Create a write configuration builder for v2 sources.
+ *
+ * This builder is used to configure and execute write operations. For example, to append to an
+ * existing table, run:
+ *
+ * {{{
+ * df.writeTo("catalog.db.table").append()
+ * }}}
+ *
+ * This can also be used to create or replace existing tables:
+ *
+ * {{{
+ * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace()
+ * }}}
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def writeTo(table: String): DataFrameWriterV2[T] = {
+ new DataFrameWriterV2[T](table, this)
+ }
+
+ /**
+ * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def persist(): this.type = {
+ sparkSession.analyze { builder =>
+ builder.getPersistBuilder.setRelation(plan.getRoot)
+ }
+ this
+ }
+
+ /**
+ * Persist this Dataset with the given storage level.
+ *
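+ * For example (a sketch, assuming `org.apache.spark.storage.StorageLevel` is imported):
+ * {{{
+ *   ds.persist(StorageLevel.MEMORY_ONLY)
+ * }}}
+ *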
+ * @param newLevel
+ * One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, `MEMORY_AND_DISK_SER`,
+ * `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc.
+ * @group basic
+ * @since 3.4.0
+ */
+ def persist(newLevel: StorageLevel): this.type = {
+ sparkSession.analyze { builder =>
+ builder.getPersistBuilder
+ .setRelation(plan.getRoot)
+ .setStorageLevel(StorageLevelProtoConverter.toConnectProtoType(newLevel))
+ }
+ this
+ }
+
+ /**
+ * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. This
+ * will not un-persist any cached data that is built upon this Dataset.
+ *
+ * @param blocking
+ * Whether to block until all blocks are deleted.
+ * @group basic
+ * @since 3.4.0
+ */
+ def unpersist(blocking: Boolean): this.type = {
+ sparkSession.analyze { builder =>
+ builder.getUnpersistBuilder
+ .setRelation(plan.getRoot)
+ .setBlocking(blocking)
+ }
+ this
+ }
+
+ /**
+ * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. This
+ * will not un-persist any cached data that is built upon this Dataset.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def unpersist(): this.type = unpersist(blocking = false)
+
+ /**
+ * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def cache(): this.type = persist()
+
+ /**
+ * Get the Dataset's current storage level, or StorageLevel.NONE if not persisted.
+ *
+ * @group basic
+ * @since 3.4.0
+ */
+ def storageLevel: StorageLevel = {
+ StorageLevelProtoConverter.toStorageLevel(
+ sparkSession
+ .analyze { builder =>
+ builder.getGetStorageLevelBuilder.setRelation(plan.getRoot)
+ }
+ .getGetStorageLevel
+ .getStorageLevel)
+ }
+
+ def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = {
+ throw new UnsupportedOperationException("withWatermark is not implemented.")
+ }
+
+ def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = {
+ throw new UnsupportedOperationException("observe is not implemented.")
+ }
+
+ def foreach(f: T => Unit): Unit = {
+ throw new UnsupportedOperationException("foreach is not implemented.")
+ }
+
+ def foreachPartition(f: Iterator[T] => Unit): Unit = {
+ throw new UnsupportedOperationException("foreachPartition is not implemented.")
+ }
+
+ def checkpoint(): Dataset[T] = {
+ throw new UnsupportedOperationException("checkpoint is not implemented.")
+ }
+
+ def checkpoint(eager: Boolean): Dataset[T] = {
+ throw new UnsupportedOperationException("checkpoint is not implemented.")
+ }
+
+ def localCheckpoint(): Dataset[T] = {
+ throw new UnsupportedOperationException("localCheckpoint is not implemented.")
+ }
+
+ def localCheckpoint(eager: Boolean): Dataset[T] = {
+ throw new UnsupportedOperationException("localCheckpoint is not implemented.")
+ }
+
+ /**
+ * Returns `true` when the logical query plans inside both [[Dataset]]s are equal and therefore
+ * return the same results.
+ *
+ * @note
+ * The equality comparison here is simplified by tolerating the cosmetic differences such as
+ * attribute names.
+ * @note
+ * This API can compare both [[Dataset]]s but can still return `false` for [[Dataset]]s that
+ * return the same results, for instance, from different plans. Such false-negative semantics
+ * can be useful, for example, when caching. This comparison may not be fast because it
+ * executes an RPC call.
+ * @since 3.4.0
+ */
+ @DeveloperApi
+ def sameSemantics(other: Dataset[T]): Boolean = {
+ sparkSession.sameSemantics(this.plan, other.plan)
+ }
+
+ /**
+ * Returns a `hashCode` of the logical query plan against this [[Dataset]].
+ *
+ * @note
+ * Unlike the standard `hashCode`, the hash is calculated against the query plan simplified by
+ * tolerating the cosmetic differences such as attribute names.
+ * @since 3.4.0
+ */
+ @DeveloperApi
+ def semanticHash(): Int = {
+ sparkSession.semanticHash(this.plan)
+ }
+
+ def toJSON: Dataset[String] = {
+ select(to_json(struct(col("*")))).as(StringEncoder)
+ }
+
+ private[sql] def analyze: proto.AnalyzePlanResponse = {
+ sparkSession.analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA)
+ }
+
+ def collectResult(): SparkResult[T] = sparkSession.execute(plan, encoder)
+
+ private[sql] def withResult[E](f: SparkResult[T] => E): E = {
+ val result = collectResult()
+ try f(result)
+ finally {
+ result.close()
+ }
+ }
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
new file mode 100644
index 0000000000000..66f591bf1fb99
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+/**
+ * A container for a [[Dataset]], used for implicit conversions in Scala.
+ *
+ * To use this, import implicit conversions in SQL:
+ * {{{
+ * val spark: SparkSession = ...
+ * import spark.implicits._
+ * }}}
+ *
+ * @since 3.4.0
+ */
+case class DatasetHolder[T] private[sql] (private val ds: Dataset[T]) {
+
+ // This is declared with parentheses to prevent the Scala compiler from treating
+ // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset.
+ def toDS(): Dataset[T] = ds
+
+ // This is declared with parentheses to prevent the Scala compiler from treating
+ // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
+ def toDF(): DataFrame = ds.toDF()
+
+ def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*)
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
new file mode 100644
index 0000000000000..5a10e1d52eb39
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -0,0 +1,417 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+
+/**
+ * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
+ * [[Dataset#cube cube]] or [[Dataset#rollup rollup]] (and also `pivot`).
+ *
+ * The main method is the `agg` function, which has multiple variants. This class also contains
+ * some first-order statistics such as `mean` and `sum` for convenience.
+ *
+ * @note
+ * This class was named `GroupedData` in Spark 1.x.
+ *
+ * @since 3.4.0
+ */
+class RelationalGroupedDataset private[sql] (
+ private[sql] val df: DataFrame,
+ private[sql] val groupingExprs: Seq[proto.Expression],
+ groupType: proto.Aggregate.GroupType,
+ pivot: Option[proto.Aggregate.Pivot] = None) {
+
+ private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
+ df.sparkSession.newDataFrame { builder =>
+ builder.getAggregateBuilder
+ .setInput(df.plan.getRoot)
+ .addAllGroupingExpressions(groupingExprs.asJava)
+ .addAllAggregateExpressions(aggExprs.map(e => e.expr).asJava)
+
+ groupType match {
+ case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
+ builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+ case proto.Aggregate.GroupType.GROUP_TYPE_CUBE =>
+ builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+ case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
+ builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+ case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT =>
+ assert(pivot.isDefined)
+ builder.getAggregateBuilder
+ .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT)
+ .setPivot(pivot.get)
+ case g => throw new UnsupportedOperationException(g.toString)
+ }
+ }
+ }
+
+ /**
+ * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The
+ * resulting `DataFrame` will also contain the grouping columns.
+ *
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * df.groupBy("department").agg(
+ * "age" -> "max",
+ * "expense" -> "sum"
+ * )
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
+ toDF((aggExpr +: aggExprs).map { case (colName, expr) =>
+ strToColumn(expr, df(colName))
+ })
+ }
+
+ /**
+ * (Scala-specific) Compute aggregates by specifying a map from column name to aggregate
+ * methods. The resulting `DataFrame` will also contain the grouping columns.
+ *
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * df.groupBy("department").agg(Map(
+ * "age" -> "max",
+ * "expense" -> "sum"
+ * ))
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def agg(exprs: Map[String, String]): DataFrame = {
+ toDF(exprs.map { case (colName, expr) =>
+ strToColumn(expr, df(colName))
+ }.toSeq)
+ }
+
+ /**
+ * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods.
+ * The resulting `DataFrame` will also contain the grouping columns.
+ *
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * import com.google.common.collect.ImmutableMap;
+ * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum"));
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ def agg(exprs: java.util.Map[String, String]): DataFrame = {
+ agg(exprs.asScala.toMap)
+ }
+
+ private[this] def strToColumn(expr: String, inputExpr: Column): Column = {
+ expr.toLowerCase(Locale.ROOT) match {
+ case "avg" | "average" | "mean" => functions.avg(inputExpr)
+ case "stddev" | "std" => functions.stddev(inputExpr)
+ case "count" | "size" => functions.count(inputExpr)
+ case name => Column.fn(name, inputExpr)
+ }
+ }
+
+ /**
+ * Compute aggregates by specifying a series of aggregate columns. Note that this function by
+ * default retains the grouping columns in its output. To not retain grouping columns, set
+ * `spark.sql.retainGroupColumns` to false.
+ *
+ * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
+ *
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ *
+ * // Scala:
+ * import org.apache.spark.sql.functions._
+ * df.groupBy("department").agg(max("age"), sum("expense"))
+ *
+ * // Java:
+ * import static org.apache.spark.sql.functions.*;
+ * df.groupBy("department").agg(max("age"), sum("expense"));
+ * }}}
+ *
+ * Note that before Spark 1.4, the default behavior is to NOT retain grouping columns. To change
+ * to that behavior, set config variable `spark.sql.retainGroupColumns` to `false`.
+ * {{{
+ * // Scala, 1.3.x:
+ * df.groupBy("department").agg($"department", max("age"), sum("expense"))
+ *
+ * // Java, 1.3.x:
+ * df.groupBy("department").agg(col("department"), max("age"), sum("expense"));
+ * }}}
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def agg(expr: Column, exprs: Column*): DataFrame = {
+ toDF((expr +: exprs).map { case c =>
+ c
+ // TODO: deal with typed columns.
+ })
+ }
+
+ /**
+ * Count the number of rows for each group. The resulting `DataFrame` will also contain the
+ * grouping columns.
+ *
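+ * For example (a sketch with hypothetical columns):
+ * {{{
+ *   // Number of employees in each department.
+ *   df.groupBy("department").count()
+ * }}}
+ *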
+ * @since 3.4.0
+ */
+ def count(): DataFrame = toDF(Seq(functions.count(functions.lit(1)).alias("count")))
+
+ /**
+ * Compute the average value for each numeric column for each group. This is an alias for
+ * `avg`. The resulting `DataFrame` will also contain the grouping columns. When specified
+ * columns are given, only compute the average values for them.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def mean(colNames: String*): DataFrame = {
+ toDF(colNames.map(colName => functions.mean(colName)))
+ }
+
+ /**
+ * Compute the max value for each numeric column for each group. The resulting `DataFrame` will
+ * also contain the grouping columns. When specified columns are given, only compute the max
+ * values for them.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def max(colNames: String*): DataFrame = {
+ toDF(colNames.map(colName => functions.max(colName)))
+ }
+
+ /**
+ * Compute the mean value for each numeric column for each group. The resulting `DataFrame`
+ * will also contain the grouping columns. When specified columns are given, only compute the
+ * mean values for them.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def avg(colNames: String*): DataFrame = {
+ toDF(colNames.map(colName => functions.avg(colName)))
+ }
+
+ /**
+ * Compute the min value for each numeric column for each group. The resulting `DataFrame` will
+ * also contain the grouping columns. When specified columns are given, only compute the min
+ * values for them.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def min(colNames: String*): DataFrame = {
+ toDF(colNames.map(colName => functions.min(colName)))
+ }
+
+ /**
+ * Compute the sum for each numeric column for each group. The resulting `DataFrame` will also
+ * contain the grouping columns. When specified columns are given, only compute the sum for
+ * them.
+ *
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def sum(colNames: String*): DataFrame = {
+ toDF(colNames.map(colName => functions.sum(colName)))
+ }
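+
+ // Usage sketch for the shortcut aggregates above (illustrative; assumes numeric columns
+ // "age" and "expense" on `df`):
+ //   df.groupBy("department").max("age", "expense")
+ //   df.groupBy("department").sum("expense")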
+
+ /**
+ * Pivots a column of the current `DataFrame` and performs the specified aggregation.
+ *
+ * There are two versions of the `pivot` function: one that requires the caller to specify the
+ * of distinct values to pivot on, and one that does not. The latter is more concise but less
+ * efficient, because Spark needs to first compute the list of distinct values internally.
+ *
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+ *
+ * // Or without specifying column values (less efficient)
+ * df.groupBy("year").pivot("course").sum("earnings")
+ * }}}
+ *
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+ * aggregation.
+ *
+ * @param pivotColumn
+ * Name of the column to pivot.
+ * @since 3.4.0
+ */
+ def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn))
+
+ /**
+ * Pivots a column of the current `DataFrame` and performs the specified aggregation. There are
+ * two versions of the `pivot` function: one that requires the caller to specify the list of distinct
+ * values to pivot on, and one that does not. The latter is more concise but less efficient,
+ * because Spark needs to first compute the list of distinct values internally.
+ *
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+ *
+ * // Or without specifying column values (less efficient)
+ * df.groupBy("year").pivot("course").sum("earnings")
+ * }}}
+ *
+ * Since Spark 3.0.0, values can be literal columns, for instance, structs. For pivoting by
+ * multiple columns, use the `struct` function to combine the columns and values:
+ *
+ * {{{
+ * df.groupBy("year")
+ * .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts"))))
+ * .agg(sum($"earnings"))
+ * }}}
+ *
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+ * aggregation.
+ *
+ * @param pivotColumn
+ * Name of the column to pivot.
+ * @param values
+ * List of values that will be translated to columns in the output DataFrame.
+ * @since 3.4.0
+ */
+ def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
+ pivot(Column(pivotColumn), values)
+ }
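+
+ // Illustrative sketch (based on the example above): with explicit values the output schema
+ // is known up front, e.g.
+ //   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+ //   // => columns: year, dotNET, Java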
+
+ /**
+ * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
+ * aggregation.
+ *
+ * There are two versions of the `pivot` function: one that requires the caller to specify the list of
+ * distinct values to pivot on, and one that does not. The latter is more concise but less
+ * efficient, because Spark needs to first compute the list of distinct values internally.
+ *
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", Arrays.